diff --git a/PaddleSlim/reader.py b/PaddleSlim/reader.py index f4a9da1e03102fceedb2bd295ebb399e19e93361..e7dc21b7024458d0bdbe5a5f58cece3a148055f5 100644 --- a/PaddleSlim/reader.py +++ b/PaddleSlim/reader.py @@ -133,33 +133,36 @@ def _reader_creator(file_list, data_dir=DATA_DIR, batch_size=1): def reader(): - with open(file_list) as flist: - full_lines = [line.strip() for line in flist] - if shuffle: - np.random.shuffle(full_lines) - if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): - # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits - trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) - trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) - per_node_lines = len(full_lines) // trainer_count - lines = full_lines[trainer_id * per_node_lines:(trainer_id + 1) - * per_node_lines] - print( - "read images from %d, length: %d, lines length: %d, total: %d" - % (trainer_id * per_node_lines, per_node_lines, len(lines), - len(full_lines))) - else: - lines = full_lines - - for line in lines: - if mode == 'train' or mode == 'val': - img_path, label = line.split() - # img_path = img_path.replace("JPEG", "jpeg") - img_path = os.path.join(data_dir, img_path) - yield img_path, int(label) - elif mode == 'test': - img_path = os.path.join(data_dir, line) - yield [img_path] + try: + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(full_lines) + if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): + # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) + per_node_lines = len(full_lines) // trainer_count + lines = full_lines[trainer_id * per_node_lines:( + trainer_id + 1) * per_node_lines] + print( + "read images from %d, length: %d, lines length: %d, total: %d" + % (trainer_id * per_node_lines, per_node_lines, + len(lines), len(full_lines))) + else: + lines = full_lines + + for line in lines: + if mode == 'train' or mode == 'val': + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + yield img_path, int(label) + elif mode == 'test': + img_path = os.path.join(data_dir, line) + yield [img_path] + except Exception as e: + print("Reader failed!\n{}".format(str(e))) + os._exit(1) mapper = functools.partial( process_image, mode=mode, color_jitter=color_jitter, rotate=rotate)