提交 5090515b 编写于 作者: X xjqbest

fix

上级 eabfd85d
...@@ -210,7 +210,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -210,7 +210,7 @@ class SingleTrainer(TranspileTrainer):
train_prog = self._model[model_dict["name"]][0] train_prog = self._model[model_dict["name"]][0]
startup_prog = self._model[model_dict["name"]][1] startup_prog = self._model[model_dict["name"]][1]
with fluid.program_guard(train_prog, startup_prog): with fluid.program_guard(train_prog, startup_prog):
self.load(j) self.load()
reader_name = model_dict["dataset_name"] reader_name = model_dict["dataset_name"]
name = "dataset." + reader_name + "." name = "dataset." + reader_name + "."
begin_time = time.time() begin_time = time.time()
...@@ -283,7 +283,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -283,7 +283,7 @@ class SingleTrainer(TranspileTrainer):
metrics_format = [] metrics_format = []
fetch_period = 20 fetch_period = 20
metrics_format.append("{}: {{}}".format("batch")) metrics_format.append("{}: {{}}".format("batch"))
for name, var in model_class.get_metrics().items(): for name, var in metrics.items():
metrics_varnames.append(var.name) metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name)) metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format) metrics_format = ", ".join(metrics_format)
...@@ -313,6 +313,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -313,6 +313,7 @@ class SingleTrainer(TranspileTrainer):
dirname = envs.get_global_env("epoch.init_model_path", None) dirname = envs.get_global_env("epoch.init_model_path", None)
if dirname is None: if dirname is None:
return return
print("going to load ", dirname)
if is_fleet: if is_fleet:
fleet.load_persistables(self._exe, dirname) fleet.load_persistables(self._exe, dirname)
else: else:
...@@ -334,7 +335,7 @@ class SingleTrainer(TranspileTrainer): ...@@ -334,7 +335,7 @@ class SingleTrainer(TranspileTrainer):
return return
feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None) feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None) fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None)
if feed_varnames is None or fetch_varnames is None: if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
return return
fetch_vars = [ fetch_vars = [
fluid.default_main_program().global_block().vars[varname] fluid.default_main_program().global_block().vars[varname]
...@@ -358,7 +359,8 @@ class SingleTrainer(TranspileTrainer): ...@@ -358,7 +359,8 @@ class SingleTrainer(TranspileTrainer):
if not need_save(epoch_id, save_interval, False): if not need_save(epoch_id, save_interval, False):
return return
dirname = envs.get_global_env("epoch.save_checkpoint_path", None) dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
assert dirname is not None if dirname is None or dirname == "":
return
dirname = os.path.join(dirname, str(epoch_id)) dirname = os.path.join(dirname, str(epoch_id))
if is_fleet: if is_fleet:
fleet.save_persistables(self._exe, dirname) fleet.save_persistables(self._exe, dirname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册