diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 1e131e61b3c07c739a0ad3b5c62798ea3c38486b..00f0ffc1fe3387f139789490c4e6557eebb646d0 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -31,6 +31,8 @@ from ppocr.postprocess import build_post_process from ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read_gif +logger = get_logger() + class TextClassifier(object): def __init__(self, args): @@ -147,5 +149,4 @@ def main(args): if __name__ == "__main__": - logger = get_logger() main(utility.parse_args()) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index a3850028f05a791b5eba3c973a4b20fdb177446e..69db67db3588ec551b9e586ce439ddc96154b899 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -30,6 +30,8 @@ from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.data import create_operators, transform from ppocr.postprocess import build_post_process +logger = get_logger() + class TextDetector(object): def __init__(self, args): @@ -158,9 +160,7 @@ class TextDetector(object): if __name__ == "__main__": args = utility.parse_args() - image_file_list = get_image_file_list(args.image_dir) - logger = get_logger() text_detector = TextDetector(args) count = 0 total_time = 0 diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index a55f671e70f0e28e62429c38af93119f635b2eb0..54dbb03b570b605e5c05150b5fabef4d45d4337a 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -13,12 +13,12 @@ # limitations under the License. import os import sys + __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) import cv2 -import copy import numpy as np import math import time @@ -30,6 +30,8 @@ from ppocr.postprocess import build_post_process from ppocr.utils.logging import get_logger from ppocr.utils.utility import get_image_file_list, check_and_read_gif +logger = get_logger() + class TextRecognizer(object): def __init__(self, args): @@ -80,7 +82,7 @@ class TextRecognizer(object): # rec_res = [] rec_res = [['', 0.0]] * img_num batch_num = self.rec_batch_num - predict_time = 0 + elapse = 0 for beg_img_no in range(0, img_num, batch_num): end_img_no = min(img_num, beg_img_no + batch_num) norm_img_batch = [] @@ -110,7 +112,9 @@ class TextRecognizer(object): output = output_tensor.copy_to_cpu() outputs.append(output) preds = outputs[0] - rec_res = self.postprocess_op(preds) + rec_result = self.postprocess_op(preds) + for rno in range(len(rec_result)): + rec_res[indices[beg_img_no + rno]] = rec_result[rno] elapse = time.time() - starttime return rec_res, elapse @@ -147,5 +151,4 @@ def main(args): if __name__ == "__main__": - logger = get_logger() main(utility.parse_args()) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 647a76b20496335cd059242890f86fffe1e3ac1a..4e81039745ff63985728113ef99fb9e3c54daca5 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -17,20 +17,17 @@ __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '../..'))) -import tools.infer.utility as utility -from ppocr.utils.utility import initial_logger -logger = initial_logger() import cv2 -import tools.infer.predict_det as predict_det -import tools.infer.predict_rec as predict_rec import copy import numpy as np -import math import time -from ppocr.utils.utility import get_image_file_list, check_and_read_gif from PIL import Image +import tools.infer.utility as utility from tools.infer.utility import draw_ocr -from tools.infer.utility import draw_ocr_box_txt +import tools.infer.predict_rec as predict_rec +import tools.infer.predict_det as predict_det +from ppocr.utils.utility import get_image_file_list, check_and_read_gif +from ppocr.utils.logging import get_logger class TextSystem(object): @@ -153,11 +150,7 @@ def main(args): scores = [rec_res[i][1] for i in range(len(rec_res))] draw_img = draw_ocr( - image, - boxes, - txts, - scores, - drop_score=drop_score) + image, boxes, txts, scores, drop_score=drop_score) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) @@ -169,4 +162,5 @@ def main(args): if __name__ == "__main__": + logger = get_logger() main(utility.parse_args()) diff --git a/tools/infer/utility.py b/tools/infer/utility.py index cb2e252292c77eb2c9c5dcc9c781d4011b4e6def..5a524516a28bd3fa1f7a9f2800ebace7b9776088 100755 --- a/tools/infer/utility.py +++ b/tools/infer/utility.py @@ -39,7 +39,8 @@ def parse_args(): parser.add_argument("--image_dir", type=str) parser.add_argument("--det_algorithm", type=str, default='DB') parser.add_argument("--det_model_dir", type=str) - parser.add_argument("--det_max_side_len", type=float, default=960) + parser.add_argument("--det_limit_side_len", type=float, default=960) + parser.add_argument("--det_limit_type", type=str, default='max') # DB parmas parser.add_argument("--det_db_thresh", type=float, default=0.3)