提交 bbbb7357 编写于 作者: X xixiaoyao

fix load ckpt and train one step

上级 033906aa
@@ -297,6 +297,7 @@ class MultiHeadTrainer(Trainer):
 iterator = self._train_reader
 self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
 for t in self._trainers:
+t._dist_train_init = True
 t._set_exe(self._exe)
 t._set_dist_train(self._distribute_train_prog)
 t._set_fetch_list(self._fetch_list)
@@ -332,6 +333,7 @@ class MultiHeadTrainer(Trainer):
 if not self._dist_train_init:
 self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
 for t in self._trainers:
+t._dist_train_init = True
 t._set_exe(self._exe)
 t._set_dist_train(self._distribute_train_prog)
 t._set_fetch_list(self._fetch_list)
......
@@ -420,7 +420,6 @@ class Trainer(object):
 saver.init_checkpoint(
 self._exe,
 model_path,
-convert=False,
 main_program=self._train_init_prog)
 elif self._pred_init_prog is not None:
 saver.init_checkpoint(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册