# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import numpy as np
import os
import lmdb
import random
import signal

import paddle
from paddle.io import Dataset, DataLoader, DistributedBatchSampler, BatchSampler

from .imaug import transform, create_operators
from ppocr.utils.logging import get_logger


def term_mp(sig_num, frame):
    """ kill all child processes """
    pid = os.getpid()
    pgid = os.getpgid(os.getpid())
    print("main proc {} exit, kill process group {}".format(pid, pgid))
    os.killpg(pgid, signal.SIGKILL)


signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)


class ModeException(Exception):
    """ ModeException: raised when an unsupported dataset mode is given """

    def __init__(self, message='', mode=''):
        message += "\nOnly the following 3 modes are supported: " \
                   "train, valid, test. Given mode is {}".format(mode)
        super(ModeException, self).__init__(message)


class SampleNumException(Exception):
    """ SampleNumException: raised when the dataset is smaller than one batch """

    def __init__(self, message='', sample_num=0, batch_size=1):
        message += "\nError: The number of the whole data ({}) " \
                   "is smaller than the batch_size ({}), and drop_last " \
                   "is turned on, so nothing will be fed into the program. " \
                   "Terminated now. Please reset batch_size to a smaller " \
                   "number or feed more data!".format(sample_num, batch_size)
        super(SampleNumException, self).__init__(message)


def get_file_list(file_list, data_dir, delimiter='\t'):
    """
    Read image path / label pairs from one or more annotation files.

    Args:
        file_list (str or list): annotation file path(s); each line is
            "<img_path><delimiter><label>".
        data_dir (str): directory the image paths are relative to.
        delimiter (str): separator between image path and label.

    Returns:
        list of dict: [{'img_path': ..., 'label': ...}, ...]
    """
    if isinstance(file_list, str):
        file_list = [file_list]
    data_source_list = []
    for file in file_list:
        with open(file) as f:
            full_lines = [line.strip() for line in f]
        for line in full_lines:
            try:
                img_path, label = line.split(delimiter)
            except ValueError:
                logger = get_logger()
                logger.warning('label error in {}'.format(line))
                continue
            img_path = os.path.join(data_dir, img_path)
            data = {'img_path': img_path, 'label': label}
            data_source_list.append(data)
    return data_source_list
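
# Example: with the default tab delimiter, an annotation line such as
#     train_images/word_001.jpg<TAB>HELLO
# is parsed by get_file_list into
#     {'img_path': os.path.join(data_dir, 'train_images/word_001.jpg'), 'label': 'HELLO'}
# and malformed lines are logged with a warning and skipped.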

class LMDBDateSet(Dataset):
    def __init__(self, config, global_config):
        super(LMDBDateSet, self).__init__()
        self.data_list = self.load_lmdb_dataset(
            config['file_list'], global_config['max_text_length'])
        random.shuffle(self.data_list)

        self.ops = create_operators(config['transforms'], global_config)

        # for rec: expose the character set collected from the label operators
        character = ''
        for op in self.ops:
            if hasattr(op, 'character'):
                character = getattr(op, 'character')

        self.info_dict = {'character': character}

    def load_lmdb_dataset(self, data_dir, max_text_length):
        self.env = lmdb.open(
            data_dir,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)
        if not self.env:
            print('cannot create lmdb from %s' % (data_dir))
            exit(0)

        filtered_index_list = []
        with self.env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode()))
            self.nSamples = nSamples
            for index in range(self.nSamples):
                index += 1  # lmdb starts with 1
                label_key = 'label-%09d'.encode() % index
                label = txn.get(label_key).decode('utf-8')
                if len(label) > max_text_length:
                    # drop samples whose label is longer than max_text_length
                    continue

                # By default, images containing characters which are not in opt.character are filtered.
                # You can add [UNK] token to `opt.character` in utils.py instead of this filtering.
                filtered_index_list.append(index)
        return filtered_index_list

    def print_lmdb_sets_info(self, lmdb_sets):
        lmdb_info_strs = []
        for dataset_idx in range(len(lmdb_sets)):
            tmp_str = " %s:%d," % (lmdb_sets[dataset_idx]['dirpath'],
                                   lmdb_sets[dataset_idx]['num_samples'])
            lmdb_info_strs.append(tmp_str)
        lmdb_info_strs = ''.join(lmdb_info_strs)
        logger = get_logger()
        logger.info("DataSummary:" + lmdb_info_strs)
        return

    def __getitem__(self, idx):
        idx = self.data_list[idx]
        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % idx
            label = txn.get(label_key)
            if label is not None:
                label = label.decode('utf-8')
                img_key = 'image-%09d'.encode() % idx
                imgbuf = txn.get(img_key)
                data = {'image': imgbuf, 'label': label}
                outs = transform(data, self.ops)
            else:
                outs = None
        if outs is None:
            # fall back to a random sample if decoding or transform failed
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return len(self.data_list)


class SimpleDataSet(Dataset):
    def __init__(self, config, global_config):
        super(SimpleDataSet, self).__init__()
        delimiter = config.get('delimiter', '\t')
        self.data_list = get_file_list(config['file_list'], config['data_dir'],
                                       delimiter)
        random.shuffle(self.data_list)

        self.ops = create_operators(config['transforms'], global_config)

        # for rec: expose the character set collected from the label operators
        character = ''
        for op in self.ops:
            if hasattr(op, 'character'):
                character = getattr(op, 'character')

        self.info_dict = {'character': character}

    def __getitem__(self, idx):
        data = copy.deepcopy(self.data_list[idx])
        with open(data['img_path'], 'rb') as f:
            img = f.read()
            data['image'] = img
        outs = transform(data, self.ops)
        if outs is None:
            # fall back to a random sample if reading or transform failed
            return self.__getitem__(np.random.randint(self.__len__()))
        return outs

    def __len__(self):
        return len(self.data_list)
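
# Worked example of the batch splitting performed in
# BatchBalancedDataLoader.__init__ below: with batch_size=16 and
# ratio_list=[0.5, 0.3, 0.2], the per-dataset batch sizes start as
# [int(16 * 0.5), int(16 * 0.3), int(16 * 0.2)] = [8, 4, 3]; the remaining
# 16 - 15 = 1 sample goes to the dataset with the largest ratio, giving
# [9, 4, 3] per combined batch (each entry is clamped to at least 1 before
# truncation).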

class BatchBalancedDataLoader(object):
    def __init__(self, dataset_list: list, ratio_list: list, distributed,
                 device, loader_args: dict):
        """
        Combine the datasets in dataset_list according to the corresponding
        proportions in ratio_list, so that the data in each batch is sampled
        from the datasets at those ratios.
        :param dataset_list: list of datasets
        :param ratio_list: list of sampling ratios (must sum to 1)
        :param loader_args: DataLoader configuration
        """
        assert sum(ratio_list) == 1 and len(dataset_list) == len(ratio_list)
        self.dataset_len = 0
        self.data_loader_list = []
        self.dataloader_iter_list = []
        all_batch_size = loader_args.pop('batch_size')
        batch_size_list = list(
            map(int, [max(1.0, all_batch_size * x) for x in ratio_list]))
        remain_num = all_batch_size - sum(batch_size_list)
        batch_size_list[np.argmax(ratio_list)] += remain_num
        for _dataset, _batch_size in zip(dataset_list, batch_size_list):
            if distributed:
                batch_sampler_class = DistributedBatchSampler
            else:
                batch_sampler_class = BatchSampler
            batch_sampler = batch_sampler_class(
                dataset=_dataset,
                batch_size=_batch_size,
                shuffle=loader_args['shuffle'],
                drop_last=loader_args['drop_last'])
            _data_loader = DataLoader(
                dataset=_dataset,
                batch_sampler=batch_sampler,
                places=device,
                num_workers=loader_args['num_workers'],
                return_list=True)
            self.data_loader_list.append(_data_loader)
            self.dataloader_iter_list.append(iter(_data_loader))
            self.dataset_len += len(_dataset)

    def __iter__(self):
        return self

    def __len__(self):
        return min([len(x) for x in self.data_loader_list])

    def __next__(self):
        batch = []
        for i, data_loader_iter in enumerate(self.dataloader_iter_list):
            try:
                _batch_i = next(data_loader_iter)
                batch.append(_batch_i)
            except StopIteration:
                # restart an exhausted loader so the other loaders keep going
                self.dataloader_iter_list[i] = iter(self.data_loader_list[i])
                _batch_i = next(self.dataloader_iter_list[i])
                batch.append(_batch_i)
            except ValueError:
                pass
        if len(batch) > 0:
            # concatenate the sub-batches field by field into one batch
            batch_list = []
            batch_item_size = len(batch[0])
            for i in range(batch_item_size):
                cur_item_list = [batch_i[i] for batch_i in batch]
                batch_list.append(paddle.concat(cur_item_list, axis=0))
        else:
            # no sub-loader produced data for this step
            raise StopIteration
        return batch_list


def fill_batch(batch):
    """
    2020.09.08: The current paddle version only supports returning data with
    the same length, so batches whose fields have inconsistent lengths are
    padded here. This method is currently only useful for text detection.
    """
    keys = list(range(len(batch[0])))
    v_max_len_dict = {}
    for k in keys:
        v_max_len_dict[k] = max([len(item[k]) for item in batch])
    for item in batch:
        length = []
        for k in keys:
            v = item[k]
            length.append(len(v))
            assert isinstance(v, np.ndarray)
            if len(v) == v_max_len_dict[k]:
                continue
            # pad the shorter array with zeros up to the max length of this field
            tmp_shape = [v_max_len_dict[k] - len(v)] + list(v[0].shape)
            tmp_array = np.zeros(tmp_shape, dtype=v[0].dtype)
            new_array = np.concatenate([v, tmp_array])
            item[k] = new_array
        # record the original (unpadded) lengths as an extra field
        item.append(length)
    return batch
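
# A minimal, self-contained sketch of fill_batch (illustrative only): the two
# samples below are assumed inputs, each with a single variable-length float
# field; the shorter one is zero-padded to the longer length and the original
# lengths are appended to every sample.
if __name__ == '__main__':
    sample_a = [np.ones((2, 4), dtype='float32')]
    sample_b = [np.ones((5, 4), dtype='float32')]
    padded = fill_batch([sample_a, sample_b])
    print([item[0].shape for item in padded])  # [(5, 4), (5, 4)]
    print([item[-1] for item in padded])  # original lengths: [[2], [5]]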