From 5234589ff2b70f6737d8978e46936c09a819c0a4 Mon Sep 17 00:00:00 2001
From: tangwei12
Date: Thu, 9 Apr 2020 02:59:19 +0000
Subject: [PATCH] add save stage

---
 trainer/single_train.py | 56 +++++++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 25 deletions(-)

diff --git a/trainer/single_train.py b/trainer/single_train.py
index 49b76350..bab3d32f 100644
--- a/trainer/single_train.py
+++ b/trainer/single_train.py
@@ -31,17 +31,13 @@ logger = logging.getLogger("fluid")
 logger.setLevel(logging.INFO)
 
 
-def need_save(epoch_id, epoch_interval, is_last=False):
-    if is_last:
-        return True
-
-    return epoch_id % epoch_interval == 0
-
-
 class SingleTrainer(Trainer):
 
     def __init__(self, config=None):
         Trainer.__init__(self, config)
+        self.inference_models = []
+        self.increment_models = []
+
         self.exe = fluid.Executor(fluid.CPUPlace())
 
         self.regist_context_processor('uninit', self.instance)
@@ -116,44 +112,48 @@ class SingleTrainerWithDataset(SingleTrainer):
         dataset.set_filelist(file_list)
         return dataset
 
-    def save(self, epoch_id):
+    def save(self, epoch_id, namespace):
+        def need_save(epoch_id, epoch_interval, is_last=False):
+            if is_last:
+                return True
+
+            if epoch_id == -1:
+                return False
+
+            return epoch_id % epoch_interval == 0
+
         def save_inference_model():
-            is_save_inference = envs.get_global_env("save.inference", False)
-            if not is_save_inference:
-                return
+            save_interval = envs.get_global_env("save.inference.epoch_interval", -1, namespace)
 
-            save_interval = envs.get_global_env("save.inference.epoch_interval", 1)
             if not need_save(epoch_id, save_interval, False):
                 return
 
-            feed_varnames = envs.get_global_env("save.inference.feed_varnames", None)
-            fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None)
+            print("save inference model is not supported now.")
+            return
+
+            feed_varnames = envs.get_global_env("save.inference.feed_varnames", None, namespace)
+            fetch_varnames = envs.get_global_env("save.inference.fetch_varnames", None, namespace)
             fetch_vars = [fluid.global_scope().vars[varname] for varname in fetch_varnames]
-            dirname = envs.get_global_env("save.inference.dirname", None)
+            dirname = envs.get_global_env("save.inference.dirname", None, namespace)
 
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))
             fluid.io.save_inference_model(dirname, feed_varnames, fetch_vars, self.exe)
+            self.inference_models.append((epoch_id, dirname))
+
         def save_persistables():
-            is_save_increment = envs.get_global_env("save.increment", False)
-            if not is_save_increment:
-                return
+            save_interval = envs.get_global_env("save.increment.epoch_interval", -1, namespace)
 
-            save_interval = envs.get_global_env("save.increment.epoch_interval", 1)
             if not need_save(epoch_id, save_interval, False):
                 return
 
-            dirname = envs.get_global_env("save.inference.dirname", None)
+            dirname = envs.get_global_env("save.increment.dirname", None, namespace)
 
             assert dirname is not None
             dirname = os.path.join(dirname, str(epoch_id))
             fluid.io.save_persistables(self.exe, dirname)
-
-        is_save = envs.get_global_env("save", False)
-
-        if not is_save:
-            return
+            self.increment_models.append((epoch_id, dirname))
 
         save_persistables()
         save_inference_model()
 
@@ -169,8 +169,14 @@ class SingleTrainerWithDataset(SingleTrainer):
                                         fetch_list=self.metric_extras[0],
                                         fetch_info=self.metric_extras[1],
                                         print_period=self.metric_extras[2])
+            self.save(i, "train")
 
         context['status'] = 'infer_pass'
 
     def infer(self, context):
         context['status'] = 'terminal_pass'
+
+    def terminal(self, context):
+        for model in self.increment_models:
+            print("epoch :{}, dir: {}".format(model[0], model[1]))
+        context['is_exit'] = True
--
GitLab
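
Note on configuration: the new save stage resolves everything through envs.get_global_env(key, default, namespace), with the "train" namespace passed in from the train loop. As a rough sketch only (the key names below are taken from the patch, but the flat-dict form and the sample values are assumptions, not part of this change), the settings it looks up are:

    # Hypothetical illustration of the keys read by save(); values are examples only.
    save_settings = {
        "save.increment.epoch_interval": 1,           # save persistables every epoch
        "save.increment.dirname": "increment_model",  # assumed output directory
        "save.inference.epoch_interval": 1,
        "save.inference.dirname": "inference_model",
        "save.inference.feed_varnames": [],           # inference export is short-circuited for now
        "save.inference.fetch_varnames": [],
    }

With epoch_interval set to a positive N, need_save allows a save on every epoch where epoch_id % N == 0; the inference branch currently prints a notice and returns before exporting anything.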