From 5090515b8d9ab8ceba082c88ec46162a9039f3ed Mon Sep 17 00:00:00 2001 From: xjqbest <173596896@qq.com> Date: Thu, 28 May 2020 12:07:36 +0800 Subject: [PATCH] fix --- core/trainers/single_trainer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/core/trainers/single_trainer.py b/core/trainers/single_trainer.py index db1bc8ef..d39d48f3 100755 --- a/core/trainers/single_trainer.py +++ b/core/trainers/single_trainer.py @@ -210,7 +210,7 @@ class SingleTrainer(TranspileTrainer): train_prog = self._model[model_dict["name"]][0] startup_prog = self._model[model_dict["name"]][1] with fluid.program_guard(train_prog, startup_prog): - self.load(j) + self.load() reader_name = model_dict["dataset_name"] name = "dataset." + reader_name + "." begin_time = time.time() @@ -283,7 +283,7 @@ class SingleTrainer(TranspileTrainer): metrics_format = [] fetch_period = 20 metrics_format.append("{}: {{}}".format("batch")) - for name, var in model_class.get_metrics().items(): + for name, var in metrics.items(): metrics_varnames.append(var.name) metrics_format.append("{}: {{}}".format(name)) metrics_format = ", ".join(metrics_format) @@ -313,6 +313,7 @@ class SingleTrainer(TranspileTrainer): dirname = envs.get_global_env("epoch.init_model_path", None) if dirname is None: return + print("going to load ", dirname) if is_fleet: fleet.load_persistables(self._exe, dirname) else: @@ -334,7 +335,7 @@ class SingleTrainer(TranspileTrainer): return feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None) fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None) - if feed_varnames is None or fetch_varnames is None: + if feed_varnames is None or fetch_varnames is None or feed_varnames == "": return fetch_vars = [ fluid.default_main_program().global_block().vars[varname] @@ -358,7 +359,8 @@ class SingleTrainer(TranspileTrainer): if not need_save(epoch_id, save_interval, False): return dirname = envs.get_global_env("epoch.save_checkpoint_path", None) - assert dirname is not None + if dirname is None or dirname == "": + return dirname = os.path.join(dirname, str(epoch_id)) if is_fleet: fleet.save_persistables(self._exe, dirname) -- GitLab