未验证 提交 7b96751c 编写于 作者: 王肖 提交者: GitHub

update trainer.py

上级 2fd3edb8
...@@ -282,13 +282,14 @@ class Trainer(object): ...@@ -282,13 +282,14 @@ class Trainer(object):
print('random init params...') print('random init params...')
self._exe.run(self._train_init_prog) self._exe.run(self._train_init_prog)
def load_pretrain(self, model_path): def load_pretrain(self, model_path, convert=False):
# load pretrain model (or ckpt) # load pretrain model (or ckpt)
assert self._exe is not None, "You need to random_init_params before load pretrain models." assert self._exe is not None, "You need to random_init_params before load pretrain models."
saver.init_pretraining_params( saver.init_pretraining_params(
self._exe, self._exe,
model_path, model_path,
convert=convert,
main_program=self._train_init_prog) main_program=self._train_init_prog)
def set_predict_head(self): def set_predict_head(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册