提交 4d01af9e 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix reader

上级 e70b3351
...@@ -54,7 +54,7 @@ TRAIN: ...@@ -54,7 +54,7 @@ TRAIN:
VALID: VALID:
batch_size: 32 batch_size: 1024
num_workers: 4 num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt" file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/" data_dir: "./dataset/ILSVRC2012/"
......
...@@ -106,6 +106,21 @@ def create_file_list(params): ...@@ -106,6 +106,21 @@ def create_file_list(params):
fout.write(file_name + " 0" + "\n") fout.write(file_name + " 0" + "\n")
def shuffle_lines(full_lines, seed=None):
    """
    Randomly shuffle a list of lines in place.

    Args:
        full_lines(list): lines to shuffle; the list itself is reordered.
        seed(int): random seed. When it is not None, a dedicated
            ``np.random.RandomState(seed)`` performs the shuffle, so the
            resulting order is reproducible for a given seed. When it is
            None, the global NumPy random state is used instead.

    Returns:
        list: the same ``full_lines`` object, now shuffled.
    """
    # Compare against None explicitly so that a seed of 0 is still honored.
    if seed is None:
        np.random.shuffle(full_lines)
    else:
        np.random.RandomState(seed).shuffle(full_lines)
    return full_lines
def get_file_list(params): def get_file_list(params):
""" """
read label list from file and shuffle the list read label list from file and shuffle the list
...@@ -119,6 +134,9 @@ def get_file_list(params): ...@@ -119,6 +134,9 @@ def get_file_list(params):
with open(params['file_list']) as flist: with open(params['file_list']) as flist:
full_lines = [line.strip() for line in flist] full_lines = [line.strip() for line in flist]
if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])
return full_lines return full_lines
...@@ -188,7 +206,7 @@ class Reader: ...@@ -188,7 +206,7 @@ class Reader:
the specific reader the specific reader
""" """
def __init__(self, config, mode='train', seed=None, places=None): def __init__(self, config, mode='train', places=None):
try: try:
self.params = config[mode.upper()] self.params = config[mode.upper()]
except KeyError: except KeyError:
...@@ -197,8 +215,6 @@ class Reader: ...@@ -197,8 +215,6 @@ class Reader:
use_mix = config.get('use_mix') use_mix = config.get('use_mix')
self.params['mode'] = mode self.params['mode'] = mode
self.shuffle = mode == "train" self.shuffle = mode == "train"
if seed is not None:
self.params['shuffle_seed'] = seed
self.collate_fn = None self.collate_fn = None
self.batch_ops = [] self.batch_ops = []
...@@ -224,7 +240,6 @@ class Reader: ...@@ -224,7 +240,6 @@ class Reader:
def __call__(self): def __call__(self):
batch_size = int(self.params['batch_size']) // trainers_num batch_size = int(self.params['batch_size']) // trainers_num
self.params['shuffle_seed'] += 1
dataset = CommonDataset(self.params) dataset = CommonDataset(self.params)
if self.params['mode'] == "train": if self.params['mode'] == "train":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册