From 7b96751c67355d2172fadc9749129a7663b15be6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E8=82=96?= Date: Fri, 10 Jan 2020 13:10:41 +0800 Subject: [PATCH] update trainer.py --- paddlepalm/trainer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index 01be7e1..e931436 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -282,13 +282,14 @@ class Trainer(object): print('random init params...') self._exe.run(self._train_init_prog) - def load_pretrain(self, model_path): + def load_pretrain(self, model_path, convert=False): # load pretrain model (or ckpt) assert self._exe is not None, "You need to random_init_params before load pretrain models." saver.init_pretraining_params( self._exe, model_path, + convert=convert, main_program=self._train_init_prog) def set_predict_head(self): -- GitLab