diff --git a/examples/multi-task/run.py b/examples/multi-task/run.py index c02a6cf1ce553d9eca9c32f2875c5ac3106d7f2e..aff2b4f137e258f38243aa9ff71c5cbc44162b83 100644 --- a/examples/multi-task/run.py +++ b/examples/multi-task/run.py @@ -78,6 +78,6 @@ if __name__ == '__main__': # step 8-2*: set saver to save model save_steps = int(n_steps-batch_size) // 2 # save_steps = 10 - trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type, is_multi=True) + trainer.set_saver(save_path=save_path, save_steps=save_steps, save_type=save_type) # step 8-3: start training trainer.train(print_steps=print_steps) \ No newline at end of file diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py index 45550623fb1b51ebdae06641edaa837343f64fc4..39ec9711cc3c88286213461d72bc3b91329cd9ed 100644 --- a/paddlepalm/multihead_trainer.py +++ b/paddlepalm/multihead_trainer.py @@ -82,7 +82,9 @@ class MultiHeadTrainer(Trainer): def get_loss(i): head = head_dict[self._trainers[i].name] + self._trainers[i]._lock_prog = True loss_var = self._trainers[i].build_forward(backbone, head) + self._trainers[i]._lock_prog = False return loss_var task_fns = {i: lambda i=i: get_loss(i) for i in range(len(self._trainers))} diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index 5519ad6755839677e068e14518526c2b85cd5a9c..38af875d480e346995041512ac46c74b69339355 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -162,8 +162,10 @@ class Trainer(object): train_prog = fluid.Program() train_init_prog = fluid.Program() - self._train_prog = train_prog - self._train_init_prog = train_init_prog + if not self._lock_prog: + self._train_prog = train_prog + self._train_init_prog = train_init_prog + if not self._lock_prog: with fluid.program_guard(train_prog, train_init_prog): net_inputs = reader_helper.create_net_inputs(input_attrs, async=False) @@ -505,7 +507,7 @@ class Trainer(object): convert=convert, main_program=self._train_init_prog) - def set_saver(self, save_path, save_steps, save_type='ckpt', is_multi=False): + def set_saver(self, save_path, save_steps, save_type='ckpt'): """ create a build-in saver into trainer. A saver will automatically save checkpoint or predict model every `save_steps` training steps. @@ -542,20 +544,11 @@ class Trainer(object): if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0: if self._save_predict: - if is_multi: - self._save(save_path, suffix='-pred.step'+str(self._cur_train_step)) - print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step))) - else: - self._save(save_path, suffix='pred.step'+str(self._cur_train_step)) - print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step))) + self._save(save_path, suffix='pred.step'+str(self._cur_train_step)) + print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step))) if self._save_ckpt: - print(self._train_prog) - if is_multi: - fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) - print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) - else: - fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) - print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) + fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) + print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) return True else: return False