diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e1b49809d199096ad06b90c4562aa5dbfa634db1..2b1ccaddcea437acb3901d6b0391dd0c7b2954b7 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -37,6 +37,8 @@ class LMDBDataSet(Dataset): if self.do_shuffle: np.random.shuffle(self.data_idx_order_list) self.ops = create_operators(dataset_config['transforms'], global_config) + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", + 2) ratio_list = dataset_config.get("ratio_list", [1.0]) self.need_reset = True in [x < 1 for x in ratio_list] @@ -88,6 +90,29 @@ class LMDBDataSet(Dataset): if imgori is None: return None return imgori + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(self.__len__())] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) + if sample_info is None: + continue + img, label = sample_info + data = {'image': img, 'label': label} + outs = transform(data, load_data_ops) + ext_data.append(data) + return ext_data def get_lmdb_sample_info(self, txn, index): label_key = 'label-%09d'.encode() % index @@ -109,6 +134,7 @@ class LMDBDataSet(Dataset): return self.__getitem__(np.random.randint(self.__len__())) img, label = sample_info data = {'image': img, 'label': label} + data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) if outs is None: return self.__getitem__(np.random.randint(self.__len__()))