未验证 提交 5cab4c0c 编写于 作者: lgcxy's avatar lgcxy 提交者: GitHub

update lmdb_dateset for ppocrv3 rec

对lmdb_dataset适配ppocrv3的RecConAug数据增强
上级 40db742e
...@@ -88,6 +88,29 @@ class LMDBDataSet(Dataset): ...@@ -88,6 +88,29 @@ class LMDBDataSet(Dataset):
if imgori is None: if imgori is None:
return None return None
return imgori return imgori
def get_ext_data(self):
ext_data_num = 0
for op in self.ops:
if hasattr(op, 'ext_data_num'):
ext_data_num = getattr(op, 'ext_data_num')
break
load_data_ops = self.ops[:self.ext_op_transform_idx]
ext_data = []
while len(ext_data) < ext_data_num:
lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(self.__len__())]
lmdb_idx = int(lmdb_idx)
file_idx = int(file_idx)
sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
file_idx)
if sample_info is None:
continue
img, label = sample_info
data = {'image': img, 'label': label}
outs = transform(data, load_data_ops)
ext_data.append(data)
return ext_data
def get_lmdb_sample_info(self, txn, index): def get_lmdb_sample_info(self, txn, index):
label_key = 'label-%09d'.encode() % index label_key = 'label-%09d'.encode() % index
...@@ -109,6 +132,7 @@ class LMDBDataSet(Dataset): ...@@ -109,6 +132,7 @@ class LMDBDataSet(Dataset):
return self.__getitem__(np.random.randint(self.__len__())) return self.__getitem__(np.random.randint(self.__len__()))
img, label = sample_info img, label = sample_info
data = {'image': img, 'label': label} data = {'image': img, 'label': label}
data['ext_data'] = self.get_ext_data()
outs = transform(data, self.ops) outs = transform(data, self.ops)
if outs is None: if outs is None:
return self.__getitem__(np.random.randint(self.__len__())) return self.__getitem__(np.random.randint(self.__len__()))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册