# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import numpy as np import imghdr import os import sys import signal from paddle import fluid from paddle.fluid.io import multiprocess_reader from . import imaug from .imaug import transform from ppcls.utils import logger trainers_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 0)) trainer_id = int(os.environ.get("PADDLE_TRAINER_ID", 0)) class ModeException(Exception): """ ModeException """ def __init__(self, message='', mode=''): message += "\nOnly the following 3 modes are supported: " \ "train, valid, test. Given mode is {}".format(mode) super(ModeException, self).__init__(message) class SampleNumException(Exception): """ SampleNumException """ def __init__(self, message='', sample_num=0, batch_size=1): message += "\nError: The number of the whole data ({}) " \ "is smaller than the batch_size ({}), and drop_last " \ "is turnning on, so nothing will feed in program, " \ "Terminated now. Please reset batch_size to a smaller " \ "number or feed more data!".format(sample_num, batch_size) super(SampleNumException, self).__init__(message) class ShuffleSeedException(Exception): """ ShuffleSeedException """ def __init__(self, message=''): message += "\nIf trainers_num > 1, the shuffle_seed must be set, " \ "because the order of batch data generated by reader " \ "must be the same in the respective processes." super(ShuffleSeedException, self).__init__(message) def check_params(params): """ check params to avoid unexpect errors Args: params(dict): """ if 'shuffle_seed' not in params: params['shuffle_seed'] = None if trainers_num > 1 and params['shuffle_seed'] is None: raise ShuffleSeedException() data_dir = params.get('data_dir', '') assert os.path.isdir(data_dir), \ "{} doesn't exist, please check datadir path".format(data_dir) if params['mode'] != 'test': file_list = params.get('file_list', '') assert os.path.isfile(file_list), \ "{} doesn't exist, please check file list path".format(file_list) def create_file_list(params): """ if mode is test, create the file list Args: params(dict): """ data_dir = params.get('data_dir', '') params['file_list'] = ".tmp.txt" imgtype_list = {'jpg', 'bmp', 'png', 'jpeg', 'rgb', 'tif', 'tiff'} with open(params['file_list'], "w", encoding='utf-8') as fout: tmp_file_list = os.listdir(data_dir) for file_name in tmp_file_list: file_path = os.path.join(data_dir, file_name) if imghdr.what(file_path) not in imgtype_list: continue fout.write(file_name + " 0" + "\n") def shuffle_lines(full_lines, seed=None): """ random shuffle lines Args: full_lines(list): seed(int): random seed """ if seed is not None: np.random.RandomState(seed).shuffle(full_lines) else: np.random.shuffle(full_lines) return full_lines def get_file_list(params): """ read label list from file and shuffle the list Args: params(dict): """ if params['mode'] == 'test': create_file_list(params) with open(params['file_list'], encoding='utf-8') as flist: full_lines = [line.strip() for line in flist] full_lines = shuffle_lines(full_lines, params["shuffle_seed"]) # use only partial data for each trainer in distributed training if params['mode'] == 'train': real_trainer_num = max(trainers_num, 1) img_per_trainer = len(full_lines) // real_trainer_num full_lines = full_lines[trainer_id::real_trainer_num][:img_per_trainer] return full_lines def create_operators(params): """ create operators based on the config Args: params(list): a dict list, used to create some operators """ assert isinstance(params, list), ('operator config should be a list') ops = [] for operator in params: assert isinstance(operator, dict) and len(operator) == 1, "yaml format error" op_name = list(operator)[0] param = {} if operator[op_name] is None else operator[op_name] op = getattr(imaug, op_name)(**param) ops.append(op) return ops def partial_reader(params, full_lines, part_id=0, part_num=1, batch_size=1): """ create a reader with partial data Args: params(dict): full_lines: label list part_id(int): part index of the current partial data part_num(int): part num of the dataset batch_size(int): batch size for one trainer """ assert part_id < part_num, ("part_num: {} should be larger " "than part_id: {}".format(part_num, part_id)) full_lines = full_lines[part_id::part_num] if params['mode'] != "test" and len(full_lines) < batch_size: raise SampleNumException('', len(full_lines), batch_size) def reader(): ops = create_operators(params['transforms']) delimiter = params.get('delimiter', ' ') for line in full_lines: img_path, label = line.split(delimiter) img_path = os.path.join(params['data_dir'], img_path) with open(img_path, 'rb') as f: img = f.read() yield (transform(img, ops), int(label)) return reader def mp_reader(params, batch_size): """ multiprocess reader Args: params(dict): """ check_params(params) full_lines = get_file_list(params) if params["mode"] == "train": full_lines = shuffle_lines(full_lines, seed=None) # NOTE: multiprocess reader is not supported on windows if sys.platform == "win32": return partial_reader(params, full_lines, 0, 1, batch_size) part_num = 1 if 'num_workers' not in params else params['num_workers'] readers = [] for part_id in range(part_num): readers.append( partial_reader(params, full_lines, part_id, part_num, batch_size)) return multiprocess_reader(readers, use_pipe=False) def term_mp(sig_num, frame): """ kill all child processes """ pid = os.getpid() pgid = os.getpgid(os.getpid()) logger.info("main proc {} exit, kill process group " "{}".format(pid, pgid)) os.killpg(pgid, signal.SIGKILL) class Reader: """ Create a reader for trainning/validate/test Args: config(dict): arguments mode(str): train or val or test seed(int): random seed used to generate same sequence in each trainer Returns: the specific reader """ def __init__(self, config, mode='train', seed=None): try: self.params = config[mode.upper()] except KeyError: raise ModeException(mode=mode) self.use_gpu = config.get("use_gpu", True) use_mix = config.get('use_mix') self.params['mode'] = mode if seed is not None: self.params['shuffle_seed'] = seed self.batch_ops = [] if use_mix and mode == "train": self.batch_ops = create_operators(self.params['mix']) def __call__(self): device_num = trainers_num # non-distributed launch if trainers_num <= 0: if self.use_gpu: device_num = fluid.core.get_cuda_device_count() else: device_num = int(os.environ.get('CPU_NUM', 1)) batch_size = int(self.params['batch_size']) // device_num def wrapper(): reader = mp_reader(self.params, batch_size) batch = [] for idx, sample in enumerate(reader()): img, label = sample batch.append((img, label)) if (idx + 1) % batch_size == 0: batch = transform(batch, self.batch_ops) yield batch batch = [] return wrapper signal.signal(signal.SIGINT, term_mp) signal.signal(signal.SIGTERM, term_mp)