提交 bbbb7357 编写于 作者: X xixiaoyao

fix load ckpt and train one step

上级 033906aa
......@@ -297,6 +297,7 @@ class MultiHeadTrainer(Trainer):
iterator = self._train_reader
self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
for t in self._trainers:
t._dist_train_init = True
t._set_exe(self._exe)
t._set_dist_train(self._distribute_train_prog)
t._set_fetch_list(self._fetch_list)
......@@ -332,6 +333,7 @@ class MultiHeadTrainer(Trainer):
if not self._dist_train_init:
self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
for t in self._trainers:
t._dist_train_init = True
t._set_exe(self._exe)
t._set_dist_train(self._distribute_train_prog)
t._set_fetch_list(self._fetch_list)
......
......@@ -420,7 +420,6 @@ class Trainer(object):
saver.init_checkpoint(
self._exe,
model_path,
convert=False,
main_program=self._train_init_prog)
elif self._pred_init_prog is not None:
saver.init_checkpoint(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册