提交 87955690 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix reader shuffle

上级 823a418d
......@@ -206,6 +206,8 @@ def mp_reader(params):
check_params(params)
full_lines = get_file_list(params)
if params["mode"] == "train":
full_lines = shuffle_lines(full_lines, seed=None)
part_num = 1 if 'num_workers' not in params else params['num_workers']
......@@ -254,11 +256,10 @@ class Reader:
self.batch_ops = create_operators(self.params['mix'])
def __call__(self):
reader = mp_reader(self.params)
batch_size = int(self.params['batch_size']) // trainers_num
def wrapper():
reader = mp_reader(self.params)
batch = []
for idx, sample in enumerate(reader()):
img, label = sample
......
......@@ -106,22 +106,20 @@ def load_params(exe, prog, path, ignore_params=[]):
fluid.io.set_program_state(prog, state)
def init_model(config, program, exe, prefix=""):
def init_model(config, program, exe):
"""
load model from checkpoint or pretrained_model
"""
checkpoints = config.get('checkpoints')
if checkpoints:
path = os.path.join(checkpoints, prefix)
fluid.load(program, path, exe)
logger.info("Finish initing model from {}".format(path))
fluid.load(program, checkpoints, exe)
logger.info("Finish initing model from {}".format(checkpoints))
return
pretrained_model = config.get('pretrained_model')
if pretrained_model:
path = os.path.join(pretrained_model, prefix)
load_params(exe, program, path)
logger.info("Finish initing model from {}".format(path))
load_params(exe, program, pretrained_model)
logger.info("Finish initing model from {}".format(pretrained_model))
def save_model(program, model_path, epoch_id, prefix='ppcls'):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册