From 2ed52a0e67aed2116430ab16c288ecc6a644a50a Mon Sep 17 00:00:00 2001
From: littletomatodonkey
Date: Tue, 11 Jan 2022 18:40:04 +0800
Subject: [PATCH] fix data to support offline augment (#5232)

* fix data to support offline augment
* fix doc
* fix data
---
 doc/doc_ch/recognition.md    | 11 +++++++++++
 ppocr/data/simple_dataset.py | 14 +++++++++++++-
 2 files changed, 24 insertions(+), 1 deletion(-)

diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index bb7d0171..51a4b69a 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -63,6 +63,17 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
 | ...
 ```
 
+除上述单张图像为一行格式之外,PaddleOCR也支持对离线增广后的数据进行训练,为了防止相同样本在同一个batch中被多次采样,我们可以将相同标签对应的图片路径写在一行中,以列表的形式给出,在训练中,PaddleOCR会随机选择列表中的一张图片进行训练。对应地,标注文件的格式如下。
+
+```
+["11.jpg", "12.jpg"] 简单可依赖
+["21.jpg", "22.jpg", "23.jpg"] 用科技让复杂的世界更简单
+3.jpg ocr
+```
+
+上述示例标注文件中,"11.jpg"和"12.jpg"的标签相同,都是`简单可依赖`,在训练的时候,对于该行标注,会随机选择其中的一张图片进行训练。
+
+
 - 测试集
 
 同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示:
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index ee8571b8..9f0ce352 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -69,6 +69,16 @@ class SimpleDataSet(Dataset):
         random.shuffle(self.data_lines)
         return
 
+    def _try_parse_filename_list(self, file_name):
+        # multiple images -> one gt label
+        if len(file_name) > 0 and file_name[0] == "[":
+            try:
+                info = json.loads(file_name)
+                file_name = random.choice(info)
+            except:
+                pass
+        return file_name
+
     def get_ext_data(self):
         ext_data_num = 0
         for op in self.ops:
@@ -85,6 +95,7 @@ class SimpleDataSet(Dataset):
             data_line = data_line.decode('utf-8')
             substr = data_line.strip("\n").split(self.delimiter)
             file_name = substr[0]
+            file_name = self._try_parse_filename_list(file_name)
             label = substr[1]
             img_path = os.path.join(self.data_dir, file_name)
             data = {'img_path': img_path, 'label': label}
@@ -95,7 +106,7 @@ class SimpleDataSet(Dataset):
                 data['image'] = img
             data = transform(data, load_data_ops)
 
-            if data is None or data['polys'].shape[1]!=4:
+            if data is None or data['polys'].shape[1] != 4:
                 continue
             ext_data.append(data)
         return ext_data
@@ -107,6 +118,7 @@ class SimpleDataSet(Dataset):
             data_line = data_line.decode('utf-8')
             substr = data_line.strip("\n").split(self.delimiter)
             file_name = substr[0]
+            file_name = self._try_parse_filename_list(file_name)
             label = substr[1]
             img_path = os.path.join(self.data_dir, file_name)
             data = {'img_path': img_path, 'label': label}
-- 
GitLab