提交 4d01af9e 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix reader

上级 e70b3351
......@@ -54,7 +54,7 @@ TRAIN:
VALID:
batch_size: 32
batch_size: 1024
num_workers: 4
file_list: "./dataset/ILSVRC2012/val_list.txt"
data_dir: "./dataset/ILSVRC2012/"
......
......@@ -106,6 +106,21 @@ def create_file_list(params):
fout.write(file_name + " 0" + "\n")
def shuffle_lines(full_lines, seed=None):
"""
random shuffle lines
Args:
full_lines(list):
seed(int): random seed
"""
if seed is not None:
np.random.RandomState(seed).shuffle(full_lines)
else:
np.random.shuffle(full_lines)
return full_lines
def get_file_list(params):
"""
read label list from file and shuffle the list
......@@ -119,6 +134,9 @@ def get_file_list(params):
with open(params['file_list']) as flist:
full_lines = [line.strip() for line in flist]
if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=params['shuffle_seed'])
return full_lines
......@@ -188,7 +206,7 @@ class Reader:
the specific reader
"""
def __init__(self, config, mode='train', seed=None, places=None):
def __init__(self, config, mode='train', places=None):
try:
self.params = config[mode.upper()]
except KeyError:
......@@ -197,8 +215,6 @@ class Reader:
use_mix = config.get('use_mix')
self.params['mode'] = mode
self.shuffle = mode == "train"
if seed is not None:
self.params['shuffle_seed'] = seed
self.collate_fn = None
self.batch_ops = []
......@@ -224,7 +240,6 @@ class Reader:
def __call__(self):
batch_size = int(self.params['batch_size']) // trainers_num
self.params['shuffle_seed'] += 1
dataset = CommonDataset(self.params)
if self.params['mode'] == "train":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册