提交 b4253934 编写于 作者: X xjqbest

fix

上级 bb7af6ed
...@@ -205,6 +205,12 @@ class SingleTrainer(TranspileTrainer): ...@@ -205,6 +205,12 @@ class SingleTrainer(TranspileTrainer):
epochs = int(self._env["epochs"]) epochs = int(self._env["epochs"])
for j in range(epochs): for j in range(epochs):
for model_dict in self._env["executor"]: for model_dict in self._env["executor"]:
if j == 0:
with fluid.scope_guard(self._model[model_dict["name"]][2]):
train_prog = self._model[model_dict["name"]][0]
startup_prog = self._model[model_dict["name"]][1]
with fluid.program_guard(train_prog, startup_prog):
self.load(j)
reader_name = model_dict["dataset_name"] reader_name = model_dict["dataset_name"]
name = "dataset." + reader_name + "." name = "dataset." + reader_name + "."
begin_time = time.time() begin_time = time.time()
...@@ -289,6 +295,16 @@ class SingleTrainer(TranspileTrainer): ...@@ -289,6 +295,16 @@ class SingleTrainer(TranspileTrainer):
def terminal(self, context): def terminal(self, context):
context['is_exit'] = True context['is_exit'] = True
def load(self, is_fleet=False):
dirname = envs.get_global_env("epoch.init_model_path", None)
if dirname is None:
return
dirname = os.path.join(dirname, str(epoch_id))
if is_fleet:
fleet.load_persistables(self._exe, dirname)
else:
fluid.io.load_persistables(self._exe, dirname)
def save(self, epoch_id, is_fleet=False): def save(self, epoch_id, is_fleet=False):
def need_save(epoch_id, epoch_interval, is_last=False): def need_save(epoch_id, epoch_interval, is_last=False):
if is_last: if is_last:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册