diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py index f923d5c824bb83a53256e4df71b21781a08d2232..1c886f9ca579e6bb88229dd7b252d62af52b0193 100644 --- a/paddlepalm/multihead_trainer.py +++ b/paddlepalm/multihead_trainer.py @@ -125,27 +125,8 @@ class MultiHeadTrainer(Trainer): branch_index=task_id_var, branch_fns=task_fns ) - # self._task_id_var = task_id_var - # self._loss_var = loss_var - # self._fetch_list = [loss_var.name] if not self._multi_task: self._init_exe_prog(for_train=False) - # return self.build_forward() - - # """ - # Build computation graph for evaluation and prediction. - - # Arguments: - # - pred_backbone: a Backbone object with phase == 'predict'. For evaluating model during training, the predict backbone should keep the same with train backbone. - # - pred_head: a Head object with phase == 'predict'. For evaluating model during training, the predict head should keep the same with train head. - # - # Return: - # - output_vars: dict type. Each value is a computational graph variable(node) argumented by pred_head outputs_attr. - # """ - # for i in self._trainers: - # assert i._predict_vars is not None, "{} need to build_predict_forward before " - # - # return output_vars def merge_inference_readers(self, readers): diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index e8996961ce9ce9c9e7f5536b43a4bf3db3f3b29f..fbc0c2aa98267a9a4f8c3ba7742db26469d3edcb 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -215,17 +215,33 @@ class Trainer(object): self._pred_name_to_position = pred_name_to_position self._pred_input_names = pred_input_names - pred_prog = fluid.Program() - self._pred_prog = pred_prog - pred_init_prog = fluid.Program() - self._pred_init_prog = pred_init_prog - with fluid.program_guard(pred_prog, pred_init_prog): + if not self._lock_prog: + pred_prog = fluid.Program() + self._pred_prog = pred_prog + pred_init_prog = fluid.Program() + self._pred_init_prog = pred_init_prog + + with fluid.program_guard(pred_prog, pred_init_prog): + pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs) + pred_bb_output_vars = pred_backbone.build(pred_net_inputs) + self._pred_net_inputs = pred_net_inputs + else: pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs) pred_bb_output_vars = pred_backbone.build(pred_net_inputs) self._pred_net_inputs = pred_net_inputs # prepare predict vars for saving inference model - with fluid.program_guard(pred_prog, pred_init_prog): + if not self._lock_prog: + with fluid.program_guard(pred_prog, pred_init_prog): + cur_inputs = helper.decode_inputs(pred_net_inputs, self.name) + self._pred_input_name_list, self._pred_input_varname_list = \ + zip(*[[k, v.name] for k,v in cur_inputs.items()]) + + pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} + scope = self.name + '.' + with fluid.unique_name.guard(scope): + output_vars = self._build_head(pred_task_inputs, phase='predict', scope=scope) + else: cur_inputs = helper.decode_inputs(pred_net_inputs, self.name) self._pred_input_name_list, self._pred_input_varname_list = \ zip(*[[k, v.name] for k,v in cur_inputs.items()]) @@ -385,20 +401,32 @@ class Trainer(object): """ assert self._train_init_prog is not None or self._pred_init_prog is not None, "model graph not built. You should at least build_forward or build_predict_forward to load its checkpoint." + # if self._train_init_prog is not None: + # saver.init_pretraining_params( + # self._exe, + # model_path, + # convert=False, + # main_program=self._train_init_prog, + # strict=True) + # elif self._pred_init_prog is not None: + # saver.init_pretraining_params( + # self._exe, + # model_path, + # convert=False, + # main_program=self._pred_init_prog, + # strict=True) if self._train_init_prog is not None: - saver.init_pretraining_params( + print('loading checkpoint into train program') + saver.init_checkpoint( self._exe, model_path, convert=False, - main_program=self._train_init_prog, - strict=True) + main_program=self._train_init_prog) elif self._pred_init_prog is not None: - saver.init_pretraining_params( + saver.init_checkpoint( self._exe, model_path, - convert=False, - main_program=self._pred_init_prog, - strict=True) + main_program=self._pred_init_prog) else: raise Exception("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.") @@ -529,6 +557,7 @@ class Trainer(object): iterator = self._predict_iterator self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel() + if output_dir is not None and not os.path.exists(output_dir): os.makedirs(output_dir)