提交 0414a874 编写于 作者: W wangxiao1021

fix bugs

上级 9e99e1de
...@@ -78,6 +78,6 @@ if __name__ == '__main__': ...@@ -78,6 +78,6 @@ if __name__ == '__main__':
# step 8-2*: set saver to save model # step 8-2*: set saver to save model
save_steps = int(n_steps-batch_size) // 2 save_steps = int(n_steps-batch_size) // 2
# save_steps = 10 # save_steps = 10
trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True) trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type)
# step 8-3: start training # step 8-3: start training
trainer.train(print_steps=print_steps) trainer.train(print_steps=print_steps)
\ No newline at end of file
...@@ -82,7 +82,9 @@ class MultiHeadTrainer(Trainer): ...@@ -82,7 +82,9 @@ class MultiHeadTrainer(Trainer):
def get_loss(i): def get_loss(i):
head = head_dict[self._trainers[i].name] head = head_dict[self._trainers[i].name]
self._trainers[i]._lock_prog = True
loss_var = self._trainers[i].build_forward(backbone, head) loss_var = self._trainers[i].build_forward(backbone, head)
self._trainers[i]._lock_prog = False
return loss_var return loss_var
task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))} task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))}
......
...@@ -162,8 +162,10 @@ class Trainer(object): ...@@ -162,8 +162,10 @@ class Trainer(object):
train_prog = fluid.Program() train_prog = fluid.Program()
train_init_prog = fluid.Program() train_init_prog = fluid.Program()
self._train_prog = train_prog if not self._lock_prog:
self._train_init_prog = train_init_prog self._train_prog = train_prog
self._train_init_prog = train_init_prog
if not self._lock_prog: if not self._lock_prog:
with fluid.program_guard(train_prog, train_init_prog): with fluid.program_guard(train_prog, train_init_prog):
net_inputs = reader_helper.create_net_inputs(input_attrs, async=False) net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
...@@ -505,7 +507,7 @@ class Trainer(object): ...@@ -505,7 +507,7 @@ class Trainer(object):
convert=convert, convert=convert,
main_program=self._train_init_prog) main_program=self._train_init_prog)
def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False): def set_saver(self, save_path, save_steps, save_type='ckpt'):
""" """
create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps. create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps.
...@@ -542,20 +544,11 @@ class Trainer(object): ...@@ -542,20 +544,11 @@ class Trainer(object):
if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0: if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
if self._save_predict: if self._save_predict:
if is_multi: self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
self._save(save_path, suffix='-pred.step'+str(self._cur_train_step)) print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
else:
self._save(save_path, suffix='pred.step'+str(self._cur_train_step))
print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
if self._save_ckpt: if self._save_ckpt:
print(self._train_prog) fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
if is_multi: print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
else:
fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
return True return True
else: else:
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册