# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import numpy as np
import os
import random
import paddle
from paddle.io import Dataset
import time
import lmdb
import cv2

from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger

logger = get_logger()


class LMDBDateSet(Dataset):
    """Dataset that reads (image, label) samples from one or more LMDB stores.

    Every leaf directory found under ``data_dir`` is opened as a separate
    LMDB environment. Samples are addressed through a global index table
    whose rows are ``(lmdb index, 1-based sample index)`` pairs.
    """

    def __init__(self, config, mode):
        """Open the LMDB stores and build the sample index.

        Args:
            config: Nested configuration mapping. Reads ``config['Global']``,
                ``config[mode]['dataset']`` (keys ``data_dir`` and
                ``transforms``) and ``config[mode]['loader']`` (key
                ``shuffle``).
            mode: Key selecting the section of ``config`` to use.
        """
        super(LMDBDateSet, self).__init__()

        global_config = config['Global']
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']
        data_dir = dataset_config['data_dir']
        self.do_shuffle = loader_config['shuffle']

        self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
        logger.info("Initialize indexs of datasets:%s" % data_dir)
        self.data_idx_order_list = self.dataset_traversal()
        if self.do_shuffle:
            # Shuffles the rows (samples) of the index table in place.
            np.random.shuffle(self.data_idx_order_list)
        self.ops = create_operators(dataset_config['transforms'],
                                    global_config)

        # # for rec
        # character = ''
        # for op in self.ops:
        #     if hasattr(op, 'character'):
        #         character = getattr(op, 'character')

        # self.info_dict = {'character': character}

    def load_hierarchical_lmdb_dataset(self, data_dir):
        """Open every leaf directory below ``data_dir`` as an LMDB store.

        Args:
            data_dir: Root directory that is walked recursively.

        Returns:
            dict mapping a running integer index (starting at 0) to a dict
            with the keys ``dirpath``, ``env``, ``txn`` and ``num_samples``.
            ``num_samples`` is parsed from the ``num-samples`` record of the
            store.
        """
        lmdb_sets = {}
        dataset_idx = 0
        for dirpath, dirnames, _filenames in os.walk(data_dir + '/'):
            # Only directories without sub-directories are treated as stores.
            if dirnames:
                continue
            env = lmdb.open(
                dirpath,
                max_readers=32,
                readonly=True,
                lock=False,
                readahead=False,
                meminit=False)
            txn = env.begin(write=False)
            num_samples = int(txn.get('num-samples'.encode()))
            lmdb_sets[dataset_idx] = {
                "dirpath": dirpath,
                "env": env,
                "txn": txn,
                "num_samples": num_samples,
            }
            dataset_idx += 1
        return lmdb_sets

    def dataset_traversal(self):
        """Build the global sample index over all opened LMDB stores.

        Returns:
            numpy array of shape ``(total_sample_num, 2)`` with numpy's
            default float dtype. Column 0 holds the store index, column 1
            the 1-based sample index inside that store. ``__getitem__``
            converts both back to ``int`` before use.
        """
        lmdb_num = len(self.lmdb_sets)
        total_sample_num = sum(self.lmdb_sets[lno]['num_samples']
                               for lno in range(lmdb_num))
        data_idx_order_list = np.zeros((total_sample_num, 2))
        beg_idx = 0
        for lno in range(lmdb_num):
            tmp_sample_num = self.lmdb_sets[lno]['num_samples']
            end_idx = beg_idx + tmp_sample_num
            data_idx_order_list[beg_idx:end_idx, 0] = lno
            # Sample keys inside a store are numbered from 1, not 0.
            data_idx_order_list[beg_idx:end_idx, 1] = np.arange(
                1, tmp_sample_num + 1)
            beg_idx = end_idx
        return data_idx_order_list

    def get_img_data(self, value):
        """Decode an encoded image buffer into a 3-channel image array.

        Args:
            value: Encoded image bytes; may be ``None`` or empty.

        Returns:
            The decoded image, or ``None`` when ``value`` is empty or the
            buffer cannot be decoded.
        """
        if not value:
            return None
        imgdata = np.frombuffer(value, dtype='uint8')
        # cv2.imdecode itself returns None for undecodable data.
        return cv2.imdecode(imgdata, cv2.IMREAD_COLOR)

    def get_lmdb_sample_info(self, txn, index):
        """Fetch the raw image bytes and label of one sample.

        Args:
            txn: LMDB read transaction of the store to query.
            index: 1-based sample index inside that store.

        Returns:
            ``(imgbuf, label)`` where ``imgbuf`` is the stored image bytes
            and ``label`` the UTF-8 decoded label, or ``None`` when either
            record is missing.
        """
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key)
        if label is None:
            return None
        label = label.decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)
        if imgbuf is None:
            # Without this guard a missing image would be handed to the
            # transform pipeline as ``None``; report the sample as
            # unavailable so the caller resamples instead.
            return None
        return imgbuf, label

    def __getitem__(self, idx):
        """Return the transformed sample at position ``idx``.

        If the sample cannot be read, or the transform pipeline returns
        ``None`` for it, a randomly chosen sample is returned instead.
        """
        lmdb_idx, file_idx = self.data_idx_order_list[idx]
        lmdb_idx = int(lmdb_idx)
        file_idx = int(file_idx)
        sample_info = self.get_lmdb_sample_info(
            self.lmdb_sets[lmdb_idx]['txn'], file_idx)
        if sample_info is None:
            return self[np.random.randint(len(self))]
        img, label = sample_info
        data = {'image': img, 'label': label}
        outs = transform(data, self.ops)
        if outs is None:
            return self[np.random.randint(len(self))]
        return outs

    def __len__(self):
        """Return the total number of samples over all LMDB stores."""
        return self.data_idx_order_list.shape[0]