From 50bcec466186d43688e085bc20039ba79aec876d Mon Sep 17 00:00:00 2001 From: Jethong <1147925384@qq.com> Date: Tue, 13 Apr 2021 13:50:44 +0800 Subject: [PATCH] fix data input format --- configs/e2e/e2e_r50_vd_pg.yml | 5 +- ppocr/data/imaug/label_ops.py | 44 +++---- ppocr/data/imaug/pg_process.py | 8 +- ppocr/data/pgnet_dataset.py | 113 ++++-------------- ppocr/metrics/e2e_metric.py | 6 +- ppocr/utils/e2e_metric/Deteval.py | 2 +- .../utils/e2e_utils/extract_textpoint_fast.py | 1 + ppocr/utils/e2e_utils/pgnet_pp_utils.py | 4 +- 8 files changed, 58 insertions(+), 125 deletions(-) diff --git a/configs/e2e/e2e_r50_vd_pg.yml b/configs/e2e/e2e_r50_vd_pg.yml index e4d868f9..cd81ffbf 100644 --- a/configs/e2e/e2e_r50_vd_pg.yml +++ b/configs/e2e/e2e_r50_vd_pg.yml @@ -69,6 +69,7 @@ Metric: Train: dataset: name: PGDataSet + data_dir: ./train_data/ label_file_list: [.././train_data/total_text/train/] ratio_list: [1.0] data_format: icdar #two data format: icdar/textnet @@ -76,6 +77,7 @@ Train: - DecodeImage: # load image img_mode: BGR channel_first: False + - E2ELabelEncode: - PGProcessTrain: batch_size: 14 # same as loader: batch_size_per_card min_crop_size: 24 @@ -98,7 +100,6 @@ Eval: - DecodeImage: # load image img_mode: RGB channel_first: False - - E2ELabelEncode: - E2EResizeForTest: max_side_len: 768 - NormalizeImage: @@ -108,7 +109,7 @@ Eval: order: 'hwc' - ToCHWImage: - KeepKeys: - keep_keys: [ 'image', 'shape', 'polys', 'strs', 'tags', 'img_id'] + keep_keys: [ 'image', 'shape', 'img_id'] loader: shuffle: False drop_last: False diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index cbb11009..44c455d8 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -187,29 +187,31 @@ class CTCLabelEncode(BaseRecLabelEncode): return dict_character -class E2ELabelEncode(BaseRecLabelEncode): - def __init__(self, - max_text_length, - character_dict_path=None, - character_type='EN', - use_space_char=False, - **kwargs): - super(E2ELabelEncode, - self).__init__(max_text_length, character_dict_path, - character_type, use_space_char) - self.pad_num = len(self.dict) # the length to pad +class E2ELabelEncode(object): + def __init__(self, **kwargs): + pass def __call__(self, data): - texts = data['strs'] - temp_texts = [] - for text in texts: - text = text.lower() - text = self.encode(text) - if text is None: - return None - text = text + [self.pad_num] * (self.max_text_len - len(text)) - temp_texts.append(text) - data['strs'] = np.array(temp_texts) + import json + label = data['label'] + label = json.loads(label) + nBox = len(label) + boxes, txts, txt_tags = [], [], [] + for bno in range(0, nBox): + box = label[bno]['points'] + txt = label[bno]['transcription'] + boxes.append(box) + txts.append(txt) + if txt in ['*', '###']: + txt_tags.append(True) + else: + txt_tags.append(False) + boxes = np.array(boxes, dtype=np.float32) + txt_tags = np.array(txt_tags, dtype=np.bool) + + data['polys'] = boxes + data['texts'] = txts + data['ignore_tags'] = txt_tags return data diff --git a/ppocr/data/imaug/pg_process.py b/ppocr/data/imaug/pg_process.py index 0c9439d7..53031064 100644 --- a/ppocr/data/imaug/pg_process.py +++ b/ppocr/data/imaug/pg_process.py @@ -88,7 +88,7 @@ class PGProcessTrain(object): return min_area_quad - def check_and_validate_polys(self, polys, tags, xxx_todo_changeme): + def check_and_validate_polys(self, polys, tags, im_size): """ check so that the text poly is in the same direction, and also filter some invalid polygons @@ -96,7 +96,7 @@ class PGProcessTrain(object): :param tags: :return: """ - (h, w) = xxx_todo_changeme + (h, w) = im_size if polys.shape[0] == 0: return polys, np.array([]), np.array([]) polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1) @@ -750,8 +750,8 @@ class PGProcessTrain(object): input_size = 512 im = data['image'] text_polys = data['polys'] - text_tags = data['tags'] - text_strs = data['strs'] + text_tags = data['ignore_tags'] + text_strs = data['texts'] h, w, _ = im.shape text_polys, text_tags, hv_tags = self.check_and_validate_polys( text_polys, text_tags, (h, w)) diff --git a/ppocr/data/pgnet_dataset.py b/ppocr/data/pgnet_dataset.py index 543dbe79..c6bc694f 100644 --- a/ppocr/data/pgnet_dataset.py +++ b/ppocr/data/pgnet_dataset.py @@ -29,20 +29,20 @@ class PGDataSet(Dataset): dataset_config = config[mode]['dataset'] loader_config = config[mode]['loader'] + self.delimiter = dataset_config.get('delimiter', '\t') label_file_list = dataset_config.pop('label_file_list') data_source_num = len(label_file_list) ratio_list = dataset_config.get("ratio_list", [1.0]) if isinstance(ratio_list, (float, int)): ratio_list = [float(ratio_list)] * int(data_source_num) - self.data_format = dataset_config.get('data_format', 'icdar') assert len( ratio_list ) == data_source_num, "The length of ratio_list should be the same as the file_list." + self.data_dir = dataset_config['data_dir'] self.do_shuffle = loader_config['shuffle'] logger.info("Initialize indexs of datasets:%s" % label_file_list) - self.data_lines = self.get_image_info_list(label_file_list, ratio_list, - self.data_format) + self.data_lines = self.get_image_info_list(label_file_list, ratio_list) self.data_idx_order_list = list(range(len(self.data_lines))) if mode.lower() == "train": self.shuffle_data_random() @@ -55,108 +55,37 @@ class PGDataSet(Dataset): random.shuffle(self.data_lines) return - def extract_polys(self, poly_txt_path): - """ - Read text_polys, txt_tags, txts from give txt file. - """ - text_polys, txt_tags, txts = [], [], [] - with open(poly_txt_path) as f: - for line in f.readlines(): - poly_str, txt = line.strip().split('\t') - poly = list(map(float, poly_str.split(','))) - text_polys.append( - np.array( - poly, dtype=np.float32).reshape(-1, 2)) - txts.append(txt) - txt_tags.append(txt == '###') - - return np.array(list(map(np.array, text_polys))), \ - np.array(txt_tags, dtype=np.bool), txts - - def extract_info_textnet(self, im_fn, img_dir=''): - """ - Extract information from line in textnet format. - """ - info_list = im_fn.split('\t') - img_path = '' - for ext in [ - 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff', 'gif', 'JPG' - ]: - if os.path.exists(os.path.join(img_dir, info_list[0] + "." + ext)): - img_path = os.path.join(img_dir, info_list[0] + "." + ext) - break - - if img_path == '': - print('Image {0} NOT found in {1}, and it will be ignored.'.format( - info_list[0], img_dir)) - - nBox = (len(info_list) - 1) // 9 - wordBBs, txts, txt_tags = [], [], [] - for n in range(0, nBox): - wordBB = list(map(float, info_list[n * 9 + 1:(n + 1) * 9])) - txt = info_list[(n + 1) * 9] - wordBBs.append([[wordBB[0], wordBB[1]], [wordBB[2], wordBB[3]], - [wordBB[4], wordBB[5]], [wordBB[6], wordBB[7]]]) - txts.append(txt) - if txt == '###': - txt_tags.append(True) - else: - txt_tags.append(False) - return img_path, np.array(wordBBs, dtype=np.float32), txt_tags, txts - - def get_image_info_list(self, file_list, ratio_list, data_format='textnet'): + def get_image_info_list(self, file_list, ratio_list): if isinstance(file_list, str): file_list = [file_list] data_lines = [] - for idx, data_source in enumerate(file_list): - image_files = [] - if data_format == 'icdar': - image_files = [(data_source, x) for x in - os.listdir(os.path.join(data_source, 'rgb')) - if x.split('.')[-1] in [ - 'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', - 'tiff', 'gif', 'JPG' - ]] - elif data_format == 'textnet': - with open(data_source) as f: - image_files = [(data_source, x.strip()) - for x in f.readlines()] - else: - print("Unrecognized data format...") - exit(-1) - random.seed(self.seed) - image_files = random.sample( - image_files, round(len(image_files) * ratio_list[idx])) - data_lines.extend(image_files) + for idx, file in enumerate(file_list): + with open(file, "rb") as f: + lines = f.readlines() + if self.mode == "train" or ratio_list[idx] < 1.0: + random.seed(self.seed) + lines = random.sample(lines, + round(len(lines) * ratio_list[idx])) + data_lines.extend(lines) return data_lines def __getitem__(self, idx): file_idx = self.data_idx_order_list[idx] - data_path, data_line = self.data_lines[file_idx] + data_line = self.data_lines[file_idx] try: - if self.data_format == 'icdar': - im_path = os.path.join(data_path, 'rgb', data_line) - poly_path = os.path.join(data_path, 'poly', - data_line.split('.')[0] + '.txt') - text_polys, text_tags, text_strs = self.extract_polys(poly_path) - else: - image_dir = os.path.join(os.path.dirname(data_path), 'image') - im_path, text_polys, text_tags, text_strs = self.extract_info_textnet( - data_line, image_dir) + data_line = data_line.decode('utf-8') + substr = data_line.strip("\n").split(self.delimiter) + file_name = substr[0] + label = substr[1] + img_path = os.path.join(self.data_dir, file_name) img_id = int(data_line.split(".")[0][3:]) - - data = { - 'img_path': im_path, - 'polys': text_polys, - 'tags': text_tags, - 'strs': text_strs, - 'img_id': img_id - } + data = {'img_path': img_path, 'label': label, 'img_id': img_id} + if not os.path.exists(img_path): + raise Exception("{} does not exist!".format(img_path)) with open(data['img_path'], 'rb') as f: img = f.read() data['image'] = img outs = transform(data, self.ops) - except Exception as e: self.logger.error( "When parsing line {}, error happened with msg: {}".format( diff --git a/ppocr/metrics/e2e_metric.py b/ppocr/metrics/e2e_metric.py index 8a604192..525aa003 100644 --- a/ppocr/metrics/e2e_metric.py +++ b/ppocr/metrics/e2e_metric.py @@ -35,11 +35,11 @@ class E2EMetric(object): self.reset() def __call__(self, preds, batch, **kwargs): - img_id = batch[5][0] + img_id = batch[2][0] e2e_info_list = [{ 'points': det_polyon, - 'text': pred_str - } for det_polyon, pred_str in zip(preds['points'], preds['strs'])] + 'texts': pred_str + } for det_polyon, pred_str in zip(preds['points'], preds['texts'])] result = get_socre(self.gt_mat_dir, img_id, e2e_info_list) self.results.append(result) diff --git a/ppocr/utils/e2e_metric/Deteval.py b/ppocr/utils/e2e_metric/Deteval.py index e30a498e..2aa09304 100755 --- a/ppocr/utils/e2e_metric/Deteval.py +++ b/ppocr/utils/e2e_metric/Deteval.py @@ -26,7 +26,7 @@ def get_socre(gt_dir, img_id, pred_dict): n = len(pred_dict) for i in range(n): points = pred_dict[i]['points'] - text = pred_dict[i]['text'] + text = pred_dict[i]['texts'] point = ",".join(map(str, points.reshape(-1, ))) det.append([point, text]) return det diff --git a/ppocr/utils/e2e_utils/extract_textpoint_fast.py b/ppocr/utils/e2e_utils/extract_textpoint_fast.py index 787cd301..06a68d3a 100644 --- a/ppocr/utils/e2e_utils/extract_textpoint_fast.py +++ b/ppocr/utils/e2e_utils/extract_textpoint_fast.py @@ -21,6 +21,7 @@ import math import numpy as np from itertools import groupby +from cv2.ximgproc import thinning as thin from skimage.morphology._skeletonize import thin diff --git a/ppocr/utils/e2e_utils/pgnet_pp_utils.py b/ppocr/utils/e2e_utils/pgnet_pp_utils.py index 64bfd372..6394d787 100644 --- a/ppocr/utils/e2e_utils/pgnet_pp_utils.py +++ b/ppocr/utils/e2e_utils/pgnet_pp_utils.py @@ -64,7 +64,7 @@ class PGNet_PostProcess(object): src_w, src_h, self.valid_set) data = { 'points': poly_list, - 'strs': keep_str_list, + 'texts': keep_str_list, } return data @@ -176,6 +176,6 @@ class PGNet_PostProcess(object): exit(-1) data = { 'points': poly_list, - 'strs': keep_str_list, + 'texts': keep_str_list, } return data -- GitLab