diff --git a/ppocr/data/rec/dataset_traversal.py b/ppocr/data/rec/dataset_traversal.py index bbd8090accd7d68bfd11270a4f1356524f22a0b6..839448e4ff3ea36cbb471e7c048ec52cbc8f0cf5 100755 --- a/ppocr/data/rec/dataset_traversal.py +++ b/ppocr/data/rec/dataset_traversal.py @@ -22,7 +22,7 @@ import string import lmdb from ppocr.utils.utility import initial_logger -from tools.infer.utility import get_image_file_list +from ppocr.utils.utility import get_image_file_list logger = initial_logger() from .img_tools import process_image, get_img_data @@ -173,26 +173,27 @@ class SimpleReader(object): img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) norm_img = process_image(img, self.image_shape) yield norm_img - with open(self.label_file_path, "rb") as fin: - label_infor_list = fin.readlines() - img_num = len(label_infor_list) - img_id_list = list(range(img_num)) - random.shuffle(img_id_list) - for img_id in range(process_id, img_num, self.num_workers): - label_infor = label_infor_list[img_id_list[img_id]] - substr = label_infor.decode('utf-8').strip("\n").split("\t") - img_path = self.img_set_dir + "/" + substr[0] - img = cv2.imread(img_path) - if img is None: - logger.info("{} does not exist!".format(img_path)) - continue - label = substr[1] - outs = process_image(img, self.image_shape, label, - self.char_ops, self.loss_type, - self.max_text_length) - if outs is None: - continue - yield outs + else: + with open(self.label_file_path, "rb") as fin: + label_infor_list = fin.readlines() + img_num = len(label_infor_list) + img_id_list = list(range(img_num)) + random.shuffle(img_id_list) + for img_id in range(process_id, img_num, self.num_workers): + label_infor = label_infor_list[img_id_list[img_id]] + substr = label_infor.decode('utf-8').strip("\n").split("\t") + img_path = self.img_set_dir + "/" + substr[0] + img = cv2.imread(img_path) + if img is None: + logger.info("{} does not exist!".format(img_path)) + continue + label = substr[1] + outs = process_image(img, self.image_shape, label, + self.char_ops, self.loss_type, + self.max_text_length) + if outs is None: + continue + yield outs def batch_iter_reader(): batch_outs = [] diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 8f4919f4d41a3674a9997f1831c888a9f01bf5cf..de7799d021ec4838ff04012deb5ed4943421a7df 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -46,7 +46,7 @@ from ppocr.data.reader_main import reader_main from ppocr.utils.save_load import init_model from ppocr.utils.character import CharacterOps from ppocr.utils.utility import create_module -from tools.infer.utility import get_image_file_list +from ppocr.utils.utility import get_image_file_list logger = initial_logger()