diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py
index 135529ac368cc825f43f251c078f812ac64acfd9..2ae106f952f4a6917e1110d1ff98c86eadea88f8 100644
--- a/paddlepalm/trainer.py
+++ b/paddlepalm/trainer.py
@@ -40,6 +40,9 @@ class Trainer(object):
         self._task_head = task_head
         self._pred_head = None
 
+        self._train_init = False
+        self._predict_init = False
+
         # if save_predict_model:
         #     self._save_predict_model = True
         #     assert pred_head is not None, "pred_head is required to save predict model."
@@ -220,7 +223,7 @@ class Trainer(object):
             for _id, block in enumerate(self._train_prog.blocks):
                 for var in block.vars:
                     print("[debug] : %d, %s" % (_id, var))
-
+        self._loss_var = loss_var
         return loss_var
 
     def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
@@ -296,30 +299,44 @@ class Trainer(object):
             distribute_feeder_fn = iterator_fn
         return distribute_feeder_fn()
 
-    def random_init_params(self):
+    def _init_exe_prog(self, for_train=True):
         assert self._train_init_prog is not None, "train graph not foung! You should build_forward first before you random init parameters."
-        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=loss_var.name)
+        self._train_init = True
+        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
         on_gpu = gpu_dev_count > 0
         self._exe = helper.build_executor(on_gpu)
+        if not for_train:
+            raise NotImplementedError()
+
+    def random_init_params(self):
+        if not self._train_init:
+            self._init_exe_prog()
+
         print('random init params...')
         self._exe.run(self._train_init_prog)
 
     def load_ckpt(self, model_path, phase='train'):
         # load pretrain model (or ckpt)
-        assert self._exe is not None, "You need to random_init_params before load checkpoints."
+        # assert self._exe is not None, "You need to random_init_params before load checkpoints."
+        if phase == 'train' and not self._train_init:
+            self._init_exe_prog()
+        if phase == 'predict' and not self._predict_init:
+            pass
 
         if phase == 'train':
             assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint."
             saver.init_pretraining_params(
                 self._exe,
                 model_path,
-                main_program=self._train_init_prog)
+                main_program=self._train_init_prog,
+                strict=True)
         elif phase == 'predict':
             assert self._pred_init_prog is not None, "predict graph not found! You should build_predict_head first before load checkpoint."
             saver.init_pretraining_params(
                 self._exe,
                 model_path,
-                main_program=self._pred_init_prog)
+                main_program=self._pred_init_prog,
+                strict=True)
         else:
             raise NotImplementedError()
 
@@ -397,6 +414,11 @@ class Trainer(object):
 
             task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')}
             self._task_head.postprocess(task_rt_outputs)
+            # rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)}
+
+            task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')}
+            self._task_head.postprocess(task_rt_outputs)
+
             self._cur_train_step += 1
             self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch
 
@@ -578,11 +600,6 @@ class Trainer(object):
         #     self._cur_train_step = 1
         # if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
         #     self._train_finish = True
-
-    @property
-    def steps_pur_epoch(self):
-        return self._steps_pur_epoch
-
     @steps_pur_epoch.setter
     def steps_pur_epoch(self, value):
         self._steps_pur_epoch = value
diff --git a/paddlepalm/utils/saver.py b/paddlepalm/utils/saver.py
index ab6d288900a807f05153e38c999802c1e8facfe8..b4f0241e2c04d6d462a1afd1bfb6ca2e4cb2179a 100644
--- a/paddlepalm/utils/saver.py
+++ b/paddlepalm/utils/saver.py
@@ -46,20 +46,24 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []):
 
 def init_pretraining_params(exe,
                             pretraining_params_path,
+                            convert,
                             main_program):
     assert os.path.exists(pretraining_params_path
                           ), "[%s] cann't be found." % pretraining_params_path
 
+    if convert:
+        assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found."
-    assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found."
-    print("Loading pretraining parameters from {}...".format(
-        pretraining_params_path))
+        with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
+            f.extractall(os.path.join(pretraining_params_path, '.temp'))
+
+        log_path = os.path.join(pretraining_params_path, '__palmmodel__')
+        pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
-    with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
-        f.extractall(os.path.join(pretraining_params_path, '.temp'))
+    else:
+        log_path = pretraining_params_path
-    log_path = os.path.join(pretraining_params_path, '__palmmodel__')
-    pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
+    print("Loading pretraining parameters from {}...".format(pretraining_params_path))
 
     def existed_params(var):
         if not isinstance(var, fluid.framework.Parameter):
@@ -73,8 +77,8 @@ def init_pretraining_params(exe, pretraining_params_path,
         main_program=main_program,
         predicate=existed_params)
-
-    shutil.rmtree(pretraining_params_path)
+    if convert:
+        shutil.rmtree(pretraining_params_path)
 
     print('')
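
A minimal usage sketch of the control flow this patch enables (not part of the patch itself): it assumes the public Trainer workflow used elsewhere in paddlepalm (build_forward, build_backward, load_ckpt, random_init_params); the Trainer construction arguments and the backbone, head, optimizer, and checkpoint path below are illustrative placeholders, not defined by this diff.

# sketch only -- 'backbone', 'task_head', 'optimizer' and './pretrained_ckpt'
# are placeholders for objects/paths built elsewhere with the paddlepalm API
import paddlepalm as palm

trainer = palm.Trainer('example_task')            # constructor args elided; see trainer.py

loss_var = trainer.build_forward(backbone, task_head)   # build_forward now also caches trainer._loss_var
trainer.build_backward(optimizer)

# Before this change, random_init_params() had to run before load_ckpt() so that
# the executor existed; now either call triggers _init_exe_prog() on demand.
trainer.load_ckpt('./pretrained_ckpt', phase='train')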