From b42539347acb7feafcc4e0c536d5ce5b05d247a1 Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Wed, 27 May 2020 23:42:54 +0800 Subject: [PATCH] fix --- core/trainers/single_trainer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index d630c21f..50e512ed 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -205,6 +205,12 @@ class SingleTrainer(TranspileTrainer): epochs = int(self._env["epochs"]) for j in range(epochs): for model_dict in self._env["executor"]: + if j == 0: + with fluid.scope_guard(self._model[model_dict["name"]][2]): + train_prog = self._model[model_dict["name"]][0] + startup_prog = self._model[model_dict["name"]][1] + with fluid.program_guard(train_prog, startup_prog): + self.load(j) reader_name = model_dict["dataset_name"] name = "dataset." + reader_name + "." begin_time = time.time() @@ -289,6 +295,16 @@ class SingleTrainer(TranspileTrainer): def terminal(self, context): context['is_exit'] = True + def load(self, is_fleet=False): + dirname = envs.get_global_env("epoch.init_model_path", None) + if dirname is None: + return + dirname = os.path.join(dirname, str(epoch_id)) + if is_fleet: + fleet.load_persistables(self._exe, dirname) + else: + fluid.io.load_persistables(self._exe, dirname) + def save(self, epoch_id, is_fleet=False): def need_save(epoch_id, epoch_interval, is_last=False): if is_last: -- GitLab