From 4d01af9e87f50e24c35798e14efff6d3b6cd360d Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Fri, 9 Oct 2020 11:47:05 +0000 Subject: [PATCH] fix reader --- .../MobileNetV3/MobileNetV3_large_x1_0.yaml | 2 +- ppcls/data/reader.py | 23 +++++++++++++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml b/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml index e924d4ea..6f9e5597 100644 --- a/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml +++ b/configs/MobileNetV3/MobileNetV3_large_x1_0.yaml @@ -54,7 +54,7 @@ TRAIN: VALID: - batch_size: 32 + batch_size: 1024 num_workers: 4 file_list: "./dataset/ILSVRC2012/val_list.txt" data_dir: "./dataset/ILSVRC2012/" diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index 68afe127..17035a03 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -106,6 +106,21 @@ def create_file_list(params): fout.write(file_name + " 0" + "\n") +def shuffle_lines(full_lines, seed=None): + """ + random shuffle lines + Args: + full_lines(list): + seed(int): random seed + """ + if seed is not None: + np.random.RandomState(seed).shuffle(full_lines) + else: + np.random.shuffle(full_lines) + + return full_lines + + def get_file_list(params): """ read label list from file and shuffle the list @@ -119,6 +134,9 @@ def get_file_list(params): with open(params['file_list']) as flist: full_lines = [line.strip() for line in flist] + if params["mode"] == "train": + full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed']) + return full_lines @@ -188,7 +206,7 @@ class Reader: the specific reader """ - def __init__(self, config, mode='train', seed=None, places=None): + def __init__(self, config, mode='train', places=None): try: self.params = config[mode.upper()] except KeyError: @@ -197,8 +215,6 @@ class Reader: use_mix = config.get('use_mix') self.params['mode'] = mode self.shuffle = mode == "train" - if seed is not None: - self.params['shuffle_seed'] = seed self.collate_fn = None self.batch_ops = [] @@ -224,7 +240,6 @@ class Reader: def __call__(self): batch_size = int(self.params['batch_size']) // trainers_num - self.params['shuffle_seed'] += 1 dataset = CommonDataset(self.params) if self.params['mode'] == "train": -- GitLab