提交 5090515b 编写于 作者: X xjqbest

fix

上级 eabfd85d
......@@ -210,7 +210,7 @@ class SingleTrainer(TranspileTrainer):
train_prog = self._model[model_dict["name"]][0]
startup_prog = self._model[model_dict["name"]][1]
with fluid.program_guard(train_prog, startup_prog):
self.load(j)
self.load()
reader_name = model_dict["dataset_name"]
name = "dataset." + reader_name + "."
begin_time = time.time()
......@@ -283,7 +283,7 @@ class SingleTrainer(TranspileTrainer):
metrics_format = []
fetch_period = 20
metrics_format.append("{}: {{}}".format("batch"))
for name, var in model_class.get_metrics().items():
for name, var in metrics.items():
metrics_varnames.append(var.name)
metrics_format.append("{}: {{}}".format(name))
metrics_format = ", ".join(metrics_format)
......@@ -313,6 +313,7 @@ class SingleTrainer(TranspileTrainer):
dirname = envs.get_global_env("epoch.init_model_path", None)
if dirname is None:
return
print("going to load ", dirname)
if is_fleet:
fleet.load_persistables(self._exe, dirname)
else:
......@@ -334,7 +335,7 @@ class SingleTrainer(TranspileTrainer):
return
feed_varnames = envs.get_global_env("epoch.save_inference_feed_varnames", None)
fetch_varnames = envs.get_global_env("epoch.save_inference_fetch_varnames", None)
if feed_varnames is None or fetch_varnames is None:
if feed_varnames is None or fetch_varnames is None or feed_varnames == "":
return
fetch_vars = [
fluid.default_main_program().global_block().vars[varname]
......@@ -358,7 +359,8 @@ class SingleTrainer(TranspileTrainer):
if not need_save(epoch_id, save_interval, False):
return
dirname = envs.get_global_env("epoch.save_checkpoint_path", None)
assert dirname is not None
if dirname is None or dirname == "":
return
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.save_persistables(self._exe, dirname)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册