diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py index 817b8fdb1af1809abf779189a556909b72ef0f49..097da768aa4da2eb023c4a346fc30f0e704ab953 100644 --- a/ppocr/data/simple_dataset.py +++ b/ppocr/data/simple_dataset.py @@ -32,12 +32,10 @@ class SimpleDataSet(Dataset): self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') data_source_num = len(label_file_list) - if data_source_num == 1: - ratio_list = [1.0] - else: - ratio_list = dataset_config.pop('ratio_list') + ratio_list = dataset_config.get("ratio_list", [1.0]) + if isinstance(ratio_list, (float, int)): + ratio_list = [float(ratio_list)] * len(data_source_num) - assert sum(ratio_list) == 1, "The sum of the ratio_list should be 1." assert len( ratio_list ) == data_source_num, "The length of ratio_list should be the same as the file_list." @@ -45,62 +43,32 @@ class SimpleDataSet(Dataset): self.do_shuffle = loader_config['shuffle'] logger.info("Initialize indexs of datasets:%s" % label_file_list) - self.data_lines_list, data_num_list = self.get_image_info_list( - label_file_list) - self.data_idx_order_list = self.dataset_traversal( - data_num_list, ratio_list, batch_size) - self.shuffle_data_random() - + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) + self.data_idx_order_list = list(range(len(self.data_lines))) + if mode.lower() == "train": + self.shuffle_data_random() self.ops = create_operators(dataset_config['transforms'], global_config) - def get_image_info_list(self, file_list): + def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): file_list = [file_list] - data_lines_list = [] - data_num_list = [] - for file in file_list: + data_lines = [] + for idx, file in enumerate(file_list): with open(file, "rb") as f: lines = f.readlines() - data_lines_list.append(lines) - data_num_list.append(len(lines)) - return data_lines_list, data_num_list - - def dataset_traversal(self, data_num_list, ratio_list, batch_size): - select_num_list = [] - dataset_num = len(data_num_list) - for dno in range(dataset_num): - select_num = round(batch_size * ratio_list[dno]) - select_num = max(select_num, 1) - select_num_list.append(select_num) - data_idx_order_list = [] - cur_index_sets = [0] * dataset_num - while True: - finish_read_num = 0 - for dataset_idx in range(dataset_num): - cur_index = cur_index_sets[dataset_idx] - if cur_index >= data_num_list[dataset_idx]: - finish_read_num += 1 - else: - select_num = select_num_list[dataset_idx] - for sno in range(select_num): - cur_index = cur_index_sets[dataset_idx] - if cur_index >= data_num_list[dataset_idx]: - break - data_idx_order_list.append((dataset_idx, cur_index)) - cur_index_sets[dataset_idx] += 1 - if finish_read_num == dataset_num: - break - return data_idx_order_list + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) + data_lines.extend(lines) + return data_lines def shuffle_data_random(self): if self.do_shuffle: - for dno in range(len(self.data_lines_list)): - random.shuffle(self.data_lines_list[dno]) + random.shuffle(self.data_lines) return def __getitem__(self, idx): - dataset_idx, file_idx = self.data_idx_order_list[idx] - data_line = self.data_lines_list[dataset_idx][file_idx] + file_idx = self.data_idx_order_list[idx] + data_line = self.data_lines[file_idx] try: data_line = data_line.decode('utf-8') substr = data_line.strip("\n").split(self.delimiter) diff --git a/tools/infer/predict_cls.py b/tools/infer/predict_cls.py index 9ec03396f95bd24704be014633916631ff98e627..420213ee5a6fce1f11c72b960d7e90344dd295ee 100755 --- a/tools/infer/predict_cls.py +++ b/tools/infer/predict_cls.py @@ -23,7 +23,7 @@ import copy import numpy as np import math import time - +import traceback import paddle.fluid as fluid import tools.infer.utility as utility @@ -106,10 +106,10 @@ class TextClassifier(object): norm_img_batch = fluid.core.PaddleTensor(norm_img_batch) self.predictor.run([norm_img_batch]) prob_out = self.output_tensors[0].copy_to_cpu() - cls_res = self.postprocess_op(prob_out) + cls_result = self.postprocess_op(prob_out) elapse += time.time() - starttime - for rno in range(len(cls_res)): - label, score = cls_res[rno] + for rno in range(len(cls_result)): + label, score = cls_result[rno] cls_res[indices[beg_img_no + rno]] = [label, score] if '180' in label and score > self.cls_thresh: img_list[indices[beg_img_no + rno]] = cv2.rotate( @@ -133,8 +133,8 @@ def main(args): img_list.append(img) try: img_list, cls_res, predict_time = text_classifier(img_list) - except Exception as e: - print(e) + except: + logger.info(traceback.format_exc()) logger.info( "ERROR!!!! \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" @@ -143,10 +143,10 @@ def main(args): "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") exit() for ino in range(len(img_list)): - print("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ + logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], cls_res[ ino])) - print("Total predict time for {} images, cost: {:.3f}".format( + logger.info("Total predict time for {} images, cost: {:.3f}".format( len(img_list), predict_time)) - if __name__ == "__main__": - main(utility.parse_args()) +if __name__ == "__main__": + main(utility.parse_args()) diff --git a/tools/infer/predict_det.py b/tools/infer/predict_det.py index 4b4825a66a145faf78d96446604329730a453381..5be27339dbae07c8d99fe442f18e64288d831f79 100755 --- a/tools/infer/predict_det.py +++ b/tools/infer/predict_det.py @@ -178,11 +178,12 @@ if __name__ == "__main__": if count > 0: total_time += elapse count += 1 - print("Predict time of {}: {}".format(image_file, elapse)) + logger.info("Predict time of {}: {}".format(image_file, elapse)) src_im = utility.draw_text_det_res(dt_boxes, image_file) img_name_pure = os.path.split(image_file)[-1] img_path = os.path.join(draw_img_save, "det_res_{}".format(img_name_pure)) cv2.imwrite(img_path, src_im) + logger.info("The visualized image saved in {}".format(img_path)) if count > 1: - print("Avg Time:", total_time / (count - 1)) + logger.info("Avg Time:", total_time / (count - 1)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index c1f20ef3c6b42772d47665032504b1fae039cbcd..c615fa0d36e9179d0e11d7e5588d223361aab349 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -22,7 +22,7 @@ import cv2 import numpy as np import math import time - +import traceback import paddle.fluid as fluid import tools.infer.utility as utility @@ -135,8 +135,8 @@ def main(args): img_list.append(img) try: rec_res, predict_time = text_recognizer(img_list) - except Exception as e: - print(e) + except: + logger.info(traceback.format_exc()) logger.info( "ERROR!!!! \n" "Please read the FAQ:https://github.com/PaddlePaddle/PaddleOCR#faq \n" @@ -145,9 +145,9 @@ def main(args): "Please set --rec_image_shape='3,32,100' and --rec_char_type='en' ") exit() for ino in range(len(img_list)): - print("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ + logger.info("Predicts of {}:{}".format(valid_image_file_list[ino], rec_res[ ino])) - print("Total predict time for {} images, cost: {:.3f}".format( + logger.info("Total predict time for {} images, cost: {:.3f}".format( len(img_list), predict_time)) diff --git a/tools/infer/predict_system.py b/tools/infer/predict_system.py index 4e81039745ff63985728113ef99fb9e3c54daca5..3facada52fbdff14044783f583739bb6c36d2094 100755 --- a/tools/infer/predict_system.py +++ b/tools/infer/predict_system.py @@ -23,17 +23,21 @@ import numpy as np import time from PIL import Image import tools.infer.utility as utility -from tools.infer.utility import draw_ocr import tools.infer.predict_rec as predict_rec import tools.infer.predict_det as predict_det +import tools.infer.predict_cls as predict_cls from ppocr.utils.utility import get_image_file_list, check_and_read_gif from ppocr.utils.logging import get_logger +from tools.infer.utility import draw_ocr_box_txt class TextSystem(object): def __init__(self, args): self.text_detector = predict_det.TextDetector(args) self.text_recognizer = predict_rec.TextRecognizer(args) + self.use_angle_cls = args.use_angle_cls + if self.use_angle_cls: + self.text_classifier = predict_cls.TextClassifier(args) def get_rotate_crop_image(self, img, points): ''' @@ -72,12 +76,13 @@ class TextSystem(object): bbox_num = len(img_crop_list) for bno in range(bbox_num): cv2.imwrite("./output/img_crop_%d.jpg" % bno, img_crop_list[bno]) - print(bno, rec_res[bno]) + logger.info(bno, rec_res[bno]) def __call__(self, img): ori_im = img.copy() dt_boxes, elapse = self.text_detector(img) - print("dt_boxes num : {}, elapse : {}".format(len(dt_boxes), elapse)) + logger.info("dt_boxes num : {}, elapse : {}".format( + len(dt_boxes), elapse)) if dt_boxes is None: return None, None img_crop_list = [] @@ -88,8 +93,15 @@ class TextSystem(object): tmp_box = copy.deepcopy(dt_boxes[bno]) img_crop = self.get_rotate_crop_image(ori_im, tmp_box) img_crop_list.append(img_crop) + if self.use_angle_cls: + img_crop_list, angle_list, elapse = self.text_classifier( + img_crop_list) + logger.info("cls num : {}, elapse : {}".format( + len(img_crop_list), elapse)) + rec_res, elapse = self.text_recognizer(img_crop_list) - print("rec_res num : {}, elapse : {}".format(len(rec_res), elapse)) + logger.info("rec_res num : {}, elapse : {}".format( + len(rec_res), elapse)) # self.print_draw_crop_rec_res(img_crop_list, rec_res) return dt_boxes, rec_res @@ -119,7 +131,8 @@ def main(args): image_file_list = get_image_file_list(args.image_dir) text_sys = TextSystem(args) is_visualize = True - tackle_img_num = 0 + font_path = args.vis_font_path + drop_score = args.drop_score for image_file in image_file_list: img, flag = check_and_read_gif(image_file) if not flag: @@ -128,20 +141,16 @@ def main(args): logger.info("error in loading image:{}".format(image_file)) continue starttime = time.time() - tackle_img_num += 1 - if not args.use_gpu and args.enable_mkldnn and tackle_img_num % 30 == 0: - text_sys = TextSystem(args) dt_boxes, rec_res = text_sys(img) elapse = time.time() - starttime - print("Predict time of %s: %.3fs" % (image_file, elapse)) + logger.info("Predict time of %s: %.3fs" % (image_file, elapse)) - drop_score = 0.5 dt_num = len(dt_boxes) for dno in range(dt_num): text, score = rec_res[dno] if score >= drop_score: text_str = "%s, %.3f" % (text, score) - print(text_str) + logger.info(text_str) if is_visualize: image = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) @@ -149,15 +158,20 @@ def main(args): txts = [rec_res[i][0] for i in range(len(rec_res))] scores = [rec_res[i][1] for i in range(len(rec_res))] - draw_img = draw_ocr( - image, boxes, txts, scores, drop_score=drop_score) + draw_img = draw_ocr_box_txt( + image, + boxes, + txts, + scores, + drop_score=drop_score, + font_path=font_path) draw_img_save = "./inference_results/" if not os.path.exists(draw_img_save): os.makedirs(draw_img_save) cv2.imwrite( os.path.join(draw_img_save, os.path.basename(image_file)), draw_img[:, :, ::-1]) - print("The visualized image saved in {}".format( + logger.info("The visualized image saved in {}".format( os.path.join(draw_img_save, os.path.basename(image_file))))