diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index 01be7e14d0ebadfc0d93c07a7d56d251c6e5c2e8..e931436f5e258b5115a9af1c60da83a541a21003 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -282,13 +282,14 @@ class Trainer(object): print('random init params...') self._exe.run(self._train_init_prog) - def load_pretrain(self, model_path): + def load_pretrain(self, model_path, convert=False): # load pretrain model (or ckpt) assert self._exe is not None, "You need to random_init_params before load pretrain models." saver.init_pretraining_params( self._exe, model_path, + convert=convert, main_program=self._train_init_prog) def set_predict_head(self):