From 879556900310fef1c119bbb9e6ca296150a97be3 Mon Sep 17 00:00:00 2001 From: littletomatodonkey Date: Sun, 19 Apr 2020 04:09:03 +0000 Subject: [PATCH] fix reader shuffle --- ppcls/data/reader.py | 5 +++-- ppcls/utils/save_load.py | 12 +++++------- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/ppcls/data/reader.py b/ppcls/data/reader.py index e68ebd77..41ebed42 100755 --- a/ppcls/data/reader.py +++ b/ppcls/data/reader.py @@ -206,6 +206,8 @@ def mp_reader(params): check_params(params) full_lines = get_file_list(params) + if params["mode"] == "train": + full_lines = shuffle_lines(full_lines, seed=None) part_num = 1 if 'num_workers' not in params else params['num_workers'] @@ -254,11 +256,10 @@ class Reader: self.batch_ops = create_operators(self.params['mix']) def __call__(self): - reader = mp_reader(self.params) - batch_size = int(self.params['batch_size']) // trainers_num def wrapper(): + reader = mp_reader(self.params) batch = [] for idx, sample in enumerate(reader()): img, label = sample diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py index fef1b816..986d5ac7 100644 --- a/ppcls/utils/save_load.py +++ b/ppcls/utils/save_load.py @@ -106,22 +106,20 @@ def load_params(exe, prog, path, ignore_params=[]): fluid.io.set_program_state(prog, state) -def init_model(config, program, exe, prefix=""): +def init_model(config, program, exe): """ load model from checkpoint or pretrained_model """ checkpoints = config.get('checkpoints') if checkpoints: - path = os.path.join(checkpoints, prefix) - fluid.load(program, path, exe) - logger.info("Finish initing model from {}".format(path)) + fluid.load(program, checkpoints, exe) + logger.info("Finish initing model from {}".format(checkpoints)) return pretrained_model = config.get('pretrained_model') if pretrained_model: - path = os.path.join(pretrained_model, prefix) - load_params(exe, program, path) - logger.info("Finish initing model from {}".format(path)) + load_params(exe, program, pretrained_model) + logger.info("Finish initing model from {}".format(pretrained_model)) def save_model(program, model_path, epoch_id, prefix='ppcls'): -- GitLab