diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index d630c21f9193f9b30bc1042675c7b3fea2780803..50e512ed52c9ff22d0e3ef4c93c2d73eaf961bb8 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -205,6 +205,12 @@ class SingleTrainer(TranspileTrainer): epochs = int(self._env["epochs"]) for j in range(epochs): for model_dict in self._env["executor"]: + if j == 0: + with fluid.scope_guard(self._model[model_dict["name"]][2]): + train_prog = self._model[model_dict["name"]][0] + startup_prog = self._model[model_dict["name"]][1] + with fluid.program_guard(train_prog, startup_prog): + self.load(j) reader_name = model_dict["dataset_name"] name = "dataset." + reader_name + "." begin_time = time.time() @@ -289,6 +295,16 @@ class SingleTrainer(TranspileTrainer): def terminal(self, context): context['is_exit'] = True + def load(self, is_fleet=False): + dirname = envs.get_global_env("epoch.init_model_path", None) + if dirname is None: + return + dirname = os.path.join(dirname, str(epoch_id)) + if is_fleet: + fleet.load_persistables(self._exe, dirname) + else: + fluid.io.load_persistables(self._exe, dirname) + def save(self, epoch_id, is_fleet=False): def need_save(epoch_id, epoch_interval, is_last=False): if is_last: