提交 0414a874 编写于 作者: W wangxiao1021

fix bugs

上级 9e99e1de
......@@ -78,6 +78,6 @@ if __name__ == '__main__':
# step 8-2*: set saver to save model
save_steps = int(n_steps-batch_size) // 2
# save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True)
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training
trainer.train(print_steps=print_steps)
\ No newline at end of file
......@@ -82,7 +82,9 @@ class MultiHeadTrainer(Trainer):
def get_loss(i):
head = head_dict[self._trainers[i].name]
self._trainers[i]._lock_prog = True
loss_var = self._trainers[i].build_forward(backbone, head)
self._trainers[i]._lock_prog = False
return loss_var
task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))}
......
......@@ -162,8 +162,10 @@ class Trainer(object):
train_prog = fluid.Program()
train_init_prog = fluid.Program()
if not self._lock_prog:
self._train_prog = train_prog
self._train_init_prog = train_init_prog
if not self._lock_prog:
with fluid.program_guard(train_prog, train_init_prog):
net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
......@@ -505,7 +507,7 @@ class Trainer(object):
convert=convert,
main_program=self._train_init_prog)
def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False):
def set_saver(self, save_path, save_steps, save_type='ckpt'):
"""
create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps.
......@@ -542,18 +544,9 @@ class Trainer(object):
if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
if self._save_predict:
if is_multi:
self._save(save_path, suffix='-pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
else:
self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
if self._save_ckpt:
print(self._train_prog)
if is_multi:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
else:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册