From 0c287c41eaff78e8661aa7a0b39be664f8dae306 Mon Sep 17 00:00:00 2001 From: WenmuZhou Date: Tue, 17 Nov 2020 17:28:28 +0800 Subject: [PATCH] =?UTF-8?q?python=E7=AB=AF=E9=A2=84=E6=B5=8B=E5=AE=8C?= =?UTF-8?q?=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tools/infer/predict_cls.py | 3 ++- tools/infer/predict_det.py | 4 ++-- tools/infer/predict_rec.py | 11 +++++++---- tools/infer/predict_system.py | 20 +++++++------------- tools/infer/utility.py | 3 ++- 5 files changed, 20 insertions(+), 21 deletions(-) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 1e131e61..00f0ffc1 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -31,6 +31,8 @@ from ppocr.postprocess import build_post_process from ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read_gif +logger = get_logger() + class TextClassifier(object): def __init__(self, args): @@ -147,5 +149,4 @@ def main(args): if __name__ == "__main__": - logger = get_logger() main(utility.parse_args()) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index a3850028..69db67db 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -30,6 +30,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.data import create_operators, transform from ppocr.postprocess import build_post_process +logger = get_logger() + class TextDetector(object): def __init__(self, args): @@ -158,9 +160,7 @@ class TextDetector(object): if __name__ == "__main__": args = utility.parse_args() - image_file_list = get_image_file_list(args.image_dir) - logger = get_logger() text_detector = TextDetector(args) count = 0 total_time = 0 diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index a55f671e..54dbb03b 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -13,12 +13,12 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) import cv2 -import copy import numpy as np import math import time @@ -30,6 +30,8 @@ from ppocr.postprocess import build_post_process from ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read_gif +logger = get_logger() + class TextRecognizer(object): def __init__(self, args): @@ -80,7 +82,7 @@ class TextRecognizer(object): # rec_res = [] rec_res = [['', 0.0]] * img_num batch_num = self.rec_batch_num - predict_time = 0 + elapse = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] @@ -110,7 +112,9 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - rec_res = self.postprocess_op(preds) + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] elapse = time.time() - starttime return rec_res, elapse @@ -147,5 +151,4 @@ def main(args): if __name__ == "__main__": - logger = get_logger() main(utility.parse_args()) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 647a76b2..4e810397 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -17,20 +17,17 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) -import tools.infer.utility as utility -from ppocr.utils.utility import initial_logger -logger = initial_logger() import cv2 -import tools.infer.predict_det as predict_det -import tools.infer.predict_rec as predict_rec import copy import numpy as np -import math import time -from ppocr.utils.utility import get_image_file_list, check_and_read_gif from PIL import Image +import tools.infer.utility as utility from tools.infer.utility import draw_ocr -from tools.infer.utility import draw_ocr_box_txt +import tools.infer.predict_rec as predict_rec +import tools.infer.predict_det as predict_det +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger class TextSystem(object): @@ -153,11 +150,7 @@ def main(args): scores = [rec_res[i][1] for i in range(len(rec_res))] draw_img = draw_ocr( - image, - boxes, - txts, - scores, - drop_score=drop_score) + image, boxes, txts, scores, drop_score=drop_score) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) @@ -169,4 +162,5 @@ def main(args): if __name__ == "__main__": + logger = get_logger() main(utility.parse_args()) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index cb2e2522..5a524516 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -39,7 +39,8 @@ def parse_args(): parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) - parser.add_argument("--det_max_side_len", type=float, default=960) + parser.add_argument("--det_limit_side_len", type=float, default=960) + parser.add_argument("--det_limit_type", type=str, default='max') # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3) -- GitLab