未验证 提交 2ed52a0e 编写于 作者: littletomatodonkey's avatar littletomatodonkey 提交者: GitHub

fix data to support offline augment (#5232)

* fix data to support offline augment

* fix doc

* fix data
上级 c910bf8d
...@@ -63,6 +63,17 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单 ...@@ -63,6 +63,17 @@ train_data/rec/train/word_002.jpg 用科技让复杂的世界更简单
| ... | ...
``` ```
除上述单张图像为一行格式之外,PaddleOCR也支持对离线增广后的数据进行训练,为了防止相同样本在同一个batch中被多次采样,我们可以将相同标签对应的图片路径写在一行中,以列表的形式给出,在训练中,PaddleOCR会随机选择列表中的一张图片进行训练。对应地,标注文件的格式如下。
```
["11.jpg", "12.jpg"] 简单可依赖
["21.jpg", "22.jpg", "23.jpg"] 用科技让复杂的世界更简单
3.jpg ocr
```
上述示例标注文件中,"11.jpg"和"12.jpg"的标签相同,都是`简单可依赖`,在训练的时候,对于该行标注,会随机选择其中的一张图片进行训练。
- 测试集 - 测试集
同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示: 同训练集类似,测试集也需要提供一个包含所有图片的文件夹(test)和一个rec_gt_test.txt,测试集的结构如下所示:
......
...@@ -69,6 +69,16 @@ class SimpleDataSet(Dataset): ...@@ -69,6 +69,16 @@ class SimpleDataSet(Dataset):
random.shuffle(self.data_lines) random.shuffle(self.data_lines)
return return
def _try_parse_filename_list(self, file_name):
# multiple images -> one gt label
if len(file_name) > 0 and file_name[0] == "[":
try:
info = json.loads(file_name)
file_name = random.choice(info)
except:
pass
return file_name
def get_ext_data(self): def get_ext_data(self):
ext_data_num = 0 ext_data_num = 0
for op in self.ops: for op in self.ops:
...@@ -85,6 +95,7 @@ class SimpleDataSet(Dataset): ...@@ -85,6 +95,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0] file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label} data = {'img_path': img_path, 'label': label}
...@@ -95,7 +106,7 @@ class SimpleDataSet(Dataset): ...@@ -95,7 +106,7 @@ class SimpleDataSet(Dataset):
data['image'] = img data['image'] = img
data = transform(data, load_data_ops) data = transform(data, load_data_ops)
if data is None or data['polys'].shape[1]!=4: if data is None or data['polys'].shape[1] != 4:
continue continue
ext_data.append(data) ext_data.append(data)
return ext_data return ext_data
...@@ -107,6 +118,7 @@ class SimpleDataSet(Dataset): ...@@ -107,6 +118,7 @@ class SimpleDataSet(Dataset):
data_line = data_line.decode('utf-8') data_line = data_line.decode('utf-8')
substr = data_line.strip("\n").split(self.delimiter) substr = data_line.strip("\n").split(self.delimiter)
file_name = substr[0] file_name = substr[0]
file_name = self._try_parse_filename_list(file_name)
label = substr[1] label = substr[1]
img_path = os.path.join(self.data_dir, file_name) img_path = os.path.join(self.data_dir, file_name)
data = {'img_path': img_path, 'label': label} data = {'img_path': img_path, 'label': label}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册