未验证 提交 ebe2d6ad 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Merge pull request #47 from wangxiao1021/updatesaver

update saver
...@@ -282,13 +282,14 @@ class Trainer(object): ...@@ -282,13 +282,14 @@ class Trainer(object):
print('random init params...') print('random init params...')
self._exe.run(self._train_init_prog) self._exe.run(self._train_init_prog)
def load_pretrain(self, model_path): def load_pretrain(self, model_path, convert=False):
# load pretrain model (or ckpt) # load pretrain model (or ckpt)
assert self._exe is not None, "You need to random_init_params before load pretrain models." assert self._exe is not None, "You need to random_init_params before load pretrain models."
saver.init_pretraining_params( saver.init_pretraining_params(
self._exe, self._exe,
model_path, model_path,
convert=convert,
main_program=self._train_init_prog) main_program=self._train_init_prog)
def set_predict_head(self): def set_predict_head(self):
......
...@@ -46,20 +46,24 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []): ...@@ -46,20 +46,24 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []):
def init_pretraining_params(exe, def init_pretraining_params(exe,
pretraining_params_path, pretraining_params_path,
convert,
main_program): main_program):
assert os.path.exists(pretraining_params_path assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path ), "[%s] cann't be found." % pretraining_params_path
if convert:
assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found."
assert os.path.exists(os.path.join(pretraining_params_path, '__palmmodel__')), "__palmmodel__ not found." with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f:
print("Loading pretraining parameters from {}...".format( f.extractall(os.path.join(pretraining_params_path, '.temp'))
pretraining_params_path))
log_path = os.path.join(pretraining_params_path, '__palmmodel__')
pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
with tarfile.open(os.path.join(pretraining_params_path, '__palmmodel__'), 'r') as f: else:
f.extractall(os.path.join(pretraining_params_path, '.temp')) log_path = pretraining_params_path
log_path = os.path.join(pretraining_params_path, '__palmmodel__') print("Loading pretraining parameters from {}...".format(pretraining_params_path))
pretraining_params_path = os.path.join(pretraining_params_path, '.temp')
def existed_params(var): def existed_params(var):
if not isinstance(var, fluid.framework.Parameter): if not isinstance(var, fluid.framework.Parameter):
...@@ -73,8 +77,8 @@ def init_pretraining_params(exe, ...@@ -73,8 +77,8 @@ def init_pretraining_params(exe,
pretraining_params_path, pretraining_params_path,
main_program=main_program, main_program=main_program,
predicate=existed_params) predicate=existed_params)
if convert:
shutil.rmtree(pretraining_params_path) shutil.rmtree(pretraining_params_path)
print('') print('')
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册