# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import os
from paddle.io import Dataset
import lmdb
import cv2
import string
import six
import pickle
from PIL import Image

from .imaug import transform, create_operators


class LMDBDataSet(Dataset):
    """Dataset reading ``(image, label)`` samples from one or more LMDBs.

    Every leaf directory (a directory without sub-directories) under
    ``config[mode]['dataset']['data_dir']`` is opened as a separate LMDB
    environment. All samples are addressed through one global index built by
    :meth:`dataset_traversal`.
    """

    def __init__(self, config, mode, logger, seed=None):
        """Open the LMDBs and build the sample index and transform ops.

        Args:
            config: full configuration dict; reads ``config['Global']``,
                ``config[mode]['dataset']`` and ``config[mode]['loader']``.
            mode: key of the section to use inside ``config``.
            logger: logger used to report initialisation.
            seed: accepted for interface compatibility; not used here.
        """
        super(LMDBDataSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexs of datasets:%s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            # Shuffles the rows (samples) of the (N, 2) index in place.
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)
        self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
                                                       1)

        ratio_list = dataset_config.get("ratio_list", [1.0])
        self.need_reset = any(x < 1 for x in ratio_list)

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory under ``data_dir`` as a read-only LMDB.

        Returns:
            dict mapping consecutive ints (0, 1, ...) to
            ``{"dirpath", "env", "txn", "num_samples"}``, where
            ``num_samples`` is read from the ``num-samples`` key.
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
            if dirnames:
                continue
            env = lmdb.open(
                dirpath,
                max_readers=32,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                "dirpath": dirpath,
                "env": env,
                "txn": txn,
                "num_samples": num_samples
            }
            dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the global sample index.

        Returns:
            ``(N, 2)`` float array whose rows are ``(lmdb_id, key_index)``.
            ``key_index`` is 1-based within each LMDB.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(self.lmdb_sets[lno]['num_samples']
                               for lno in range(lmdb_num))
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            data_idx_order_list[beg_idx:end_idx, 1] = np.arange(
                1, tmp_sample_num + 1)
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer with OpenCV.

        Returns the decoded image, or ``None`` when ``value`` is empty or
        cannot be decoded (``cv2.imdecode`` returns ``None`` on failure).
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        return cv2.imdecode(imgdata, 1)

    def _locate(self, idx):
        """Map global index ``idx`` to ``(txn, file_idx)`` of its LMDB."""
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        return self.lmdb_sets[int(lmdb_idx)]['txn'], int(file_idx)

    def get_ext_data(self):
        """Draw random extra samples for ops that request them.

        The first op with an ``ext_data_num`` attribute sets how many samples
        are collected (0 if no op has one). Each sample goes only through the
        first ``self.ext_op_transform_idx`` ops. Samples that are missing or
        rejected by those ops are skipped.
        """
        ext_data_num = next(
            (getattr(op, 'ext_data_num') for op in self.ops
             if hasattr(op, 'ext_data_num')), 0)
        load_data_ops = self.ops[:self.ext_op_transform_idx]
        ext_data = []

        # NOTE(review): loops forever if no sample can ever be loaded.
        while len(ext_data) < ext_data_num:
            txn, file_idx = self._locate(np.random.randint(len(self)))
            sample_info = self.get_lmdb_sample_info(txn, file_idx)
            if sample_info is None:
                continue
            img, label = sample_info
            data = {'image': img, 'label': label}
            data = transform(data, load_data_ops)
            if data is None:
                continue
            ext_data.append(data)
        return ext_data

    def get_lmdb_sample_info(self, txn, index):
        """Return ``(image_bytes, label)`` for key ``index``, or ``None``.

        ``None`` is returned when the ``label-%09d`` key is missing. The image
        buffer may itself be ``None`` if its ``image-%09d`` key is absent.
        """
        label = txn.get(b'label-%09d' % index)
        if label is None:
            return None
        label = label.decode('utf-8')
        imgbuf = txn.get(b'image-%09d' % index)
        return imgbuf, label

    def _load_item(self, idx):
        """Load and transform sample ``idx``; ``None`` means skip it."""
        txn, file_idx = self._locate(idx)
        sample_info = self.get_lmdb_sample_info(txn, file_idx)
        if sample_info is None:
            return None
        img, label = sample_info
        data = {'image': img, 'label': label}
        data['ext_data'] = self.get_ext_data()
        return transform(data, self.ops)

    def __getitem__(self, idx):
        """Return the transformed sample at ``idx``.

        If the sample is missing or a transform rejects it, a random index is
        tried instead. This repeats until a sample succeeds. Retries use a
        loop, not recursion, so a run of bad samples cannot exhaust the
        recursion limit.
        """
        outs = self._load_item(idx)
        while outs is None:
            outs = self._load_item(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return self.data_idx_order_list.shape[0]


class LMDBDataSetSR(LMDBDataSet):
    """LMDB dataset of ``(image_hr, image_lr, label)`` samples.

    Reads a high-resolution and a low-resolution image per sample.
    """

    # Label filtering settings used by get_lmdb_sample_info.
    voc_type = 'upper'
    max_len = 100
    test = False

    def buf2PIL(self, txn, key, type='RGB'):
        """Decode the image stored under ``key`` into a PIL image.

        The image is converted to mode ``type``.

        Raises:
            IOError: if ``key`` is missing or the bytes cannot be decoded.
                PIL's decode errors are ``OSError`` subclasses.
        """
        imgbuf = txn.get(key)
        if imgbuf is None:
            raise IOError('missing image key: %r' % (key, ))
        buf = six.BytesIO(imgbuf)
        return Image.open(buf).convert(type)

    def str_filt(self, str_, voc_type):
        """Drop every character of ``str_`` not in the ``voc_type`` alphabet.

        For ``voc_type == 'lower'`` the string is lower-cased first.
        """
        alpha_dict = {
            'digit': string.digits,
            'lower': string.digits + string.ascii_lowercase,
            'upper': string.digits + string.ascii_letters,
            'all': string.digits + string.ascii_letters + string.punctuation
        }
        if voc_type == 'lower':
            str_ = str_.lower()
        allowed = alpha_dict[voc_type]
        return ''.join(char for char in str_ if char in allowed)

    def get_lmdb_sample_info(self, txn, index):
        """Read the HR/LR image pair and filtered label for key ``index``.

        Returns:
            ``(img_HR, img_lr, label_str)``, or ``None`` in any of these
            cases:
            - the label key is missing;
            - the raw label is longer than ``self.max_len``;
            - either image is missing or cannot be decoded.
            ``None`` makes ``__getitem__`` try another random sample.
        """
        label = txn.get(b'label-%09d' % index)
        if label is None:
            return None
        word = str(label.decode())
        if len(word) > self.max_len:
            return None
        img_HR_key = b'image_hr-%09d' % index  # 128*32
        img_lr_key = b'image_lr-%09d' % index  # 64*16
        try:
            img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
            img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
        except IOError:
            return None
        label_str = self.str_filt(word, self.voc_type)
        return img_HR, img_lr, label_str

    def _load_item(self, idx):
        """Load and transform sample ``idx``; ``None`` means skip it."""
        txn, file_idx = self._locate(idx)
        sample_info = self.get_lmdb_sample_info(txn, file_idx)
        if sample_info is None:
            return None
        img_HR, img_lr, label_str = sample_info
        data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
        return transform(data, self.ops)


class LMDBDataSetTableMaster(LMDBDataSet):
    """Single-LMDB dataset of pickled TableMASTER annotation samples."""

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open ``data_dir`` itself as one read-only LMDB.

        Unlike the base class, sub-directories are not walked. The sample
        count is read from the pickled ``__len__`` key.
        """
        env = lmdb.open(
            data_dir,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)
        txn = env.begin(write=False)
        # SECURITY: pickle.loads can execute arbitrary code for crafted data;
        # only open LMDB files from trusted sources.
        num_samples = int(pickle.loads(txn.get(b"__len__")))
        return {
            0: {
                "dirpath": data_dir,
                "env": env,
                "txn": txn,
                "num_samples": num_samples
            }
        }

    def get_lmdb_sample_info(self, txn, index):
        """Unpickle and parse the sample stored under key ``str(index)``.

        The stored value is indexed as ``(img_name, img, info_lines)``.
        ``info_lines`` is newline-separated:
        - line 0 is a name;
        - line 1 is comma-separated structure tokens;
        - each further line is a comma-separated integer bbox.

        Returns:
            dict with ``file_name``, ``structure``, ``cells`` and ``image``.
            Returns ``None`` if the key is missing or cannot be unpickled.
        """
        # SECURITY: pickle.loads on LMDB contents -- trusted data only.
        try:
            data = pickle.loads(txn.get(str(index).encode('utf8')))
        except Exception:
            # Missing key (pickle.loads(None) -> TypeError) or corrupt data:
            # skip the sample. Unlike a bare except, this does not swallow
            # KeyboardInterrupt/SystemExit.
            return None

        # img_name, img, info_lines
        file_name = data[0]
        img_bytes = data[1]
        info_lines = data[2]  # raw data from TableMASTER annotation file.

        raw_data = info_lines.strip().split('\n')
        # Samples whose length exceeds max_seq_len are not filtered here.
        text = raw_data[1].split(',')
        bboxes = [{
            'bbox': [int(v) for v in bsl.strip().split(',')],
            'tokens': ['1', '2']
        } for bsl in raw_data[2:]]

        return {
            'file_name': file_name,
            'structure': text,
            'cells': bboxes,
            'image': img_bytes
        }

    def _load_item(self, idx):
        """Load and transform sample ``idx``; ``None`` means skip it."""
        txn, file_idx = self._locate(idx)
        data = self.get_lmdb_sample_info(txn, file_idx)
        if data is None:
            return None
        return transform(data, self.ops)