diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e1b49809d199096ad06b90c4562aa5dbfa634db1..0c1ea9465c4b61f4cb7106160e31c217a959b3ea 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -88,6 +88,29 @@ class LMDBDataSet(Dataset): if imgori is None: return None return imgori + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(self.__len__())] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) + if sample_info is None: + continue + img, label = sample_info + data = {'image': img, 'label': label} + outs = transform(data, load_data_ops) + ext_data.append(data) + return ext_data def get_lmdb_sample_info(self, txn, index): label_key = 'label-%09d'.encode() % index @@ -109,6 +132,7 @@ class LMDBDataSet(Dataset): return self.__getitem__(np.random.randint(self.__len__())) img, label = sample_info data = {'image': img, 'label': label} + data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) if outs is None: return self.__getitem__(np.random.randint(self.__len__()))