diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index 3a51cefec2f1da2c96cceb6482d8303aa136b78a..295643e401481d30cf433346727f39d4a4c7d2f4 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -40,6 +40,8 @@ class LMDBDataSet(Dataset):
         if self.do_shuffle:
             np.random.shuffle(self.data_idx_order_list)
         self.ops = create_operators(dataset_config['transforms'], global_config)
+        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
+                                                       1)
 
         ratio_list = dataset_config.get("ratio_list", [1.0])
         self.need_reset = True in [x < 1 for x in ratio_list]
@@ -92,6 +94,41 @@ class LMDBDataSet(Dataset):
             return None
         return imgori
 
+    def get_ext_data(self):
+        """Collect randomly drawn extra samples for the transform ops.
+
+        The sample count is the ``ext_data_num`` attribute of the first op in
+        ``self.ops`` that defines one (0, hence an empty list, if none does).
+        Each sample comes from a randomly indexed entry of
+        ``self.data_idx_order_list`` and is passed through only the first
+        ``self.ext_op_transform_idx`` transforms.
+
+        Returns:
+            list: ``ext_data_num`` results of ``transform``.
+        """
+        ext_data_num = next(
+            (op.ext_data_num for op in self.ops
+             if hasattr(op, 'ext_data_num')), 0)
+        load_data_ops = self.ops[:self.ext_op_transform_idx]
+        ext_data = []
+
+        # Entries that cannot be read, or for which transform returns None,
+        # are skipped and another random entry is drawn.
+        # NOTE: this never terminates if no entry yields a valid sample.
+        while len(ext_data) < ext_data_num:
+            lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
+                len(self))]
+            sample_info = self.get_lmdb_sample_info(
+                self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx))
+            if sample_info is None:
+                continue
+            img, label = sample_info
+            data = transform({'image': img, 'label': label}, load_data_ops)
+            if data is None:
+                continue
+            ext_data.append(data)
+        return ext_data
+
     def get_lmdb_sample_info(self, txn, index):
         label_key = 'label-%09d'.encode() % index
         label = txn.get(label_key)
@@ -112,6 +149,7 @@ class LMDBDataSet(Dataset):
             return self.__getitem__(np.random.randint(self.__len__()))
         img, label = sample_info
         data = {'image': img, 'label': label}
+        data['ext_data'] = self.get_ext_data()
         outs = transform(data, self.ops)
         if outs is None:
             return self.__getitem__(np.random.randint(self.__len__()))