Unverified commit 4d28bb18, authored by Xing Wu, committed by GitHub

update save/load for v1.7 (#4332)

* update save/load for v1.7

* update load api

* update load api
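Paddle 1.7 replaces the per-variable fluid.io save/load helpers with fluid.save / fluid.load, which write and read a single path prefix. Below is a minimal sketch of the pattern this commit applies across the touched scripts, assuming a built program and executor; save_checkpoint and load_checkpoint are hypothetical helper names, not functions from this repository:

    import os
    import paddle.fluid as fluid

    def save_checkpoint(program, model_dir, tag):
        # 1.7-style save: a single path prefix ("checkpoint") instead of one
        # file per persistable variable (the old fluid.io.save_params /
        # fluid.io.save_persistables layout).
        save_path = os.path.join(model_dir, tag, "checkpoint")
        fluid.save(program, save_path)
        return save_path

    def load_checkpoint(program, model_dir, tag, exe):
        # 1.7-style load from the same prefix written by fluid.save.
        load_path = os.path.join(model_dir, tag, "checkpoint")
        fluid.load(program, load_path, exe)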
Parent d0937f0b
@@ -127,7 +127,8 @@ def infer():
     dir_name = args.reload_model
     print("dir name", dir_name)
-    fluid.io.load_params(exe, dir_name)
+    dir_name = os.path.join(dir_name, "checkpoint")
+    fluid.load(main_program, dir_name, exe)
     train_data_iter = reader.get_data_iter(infer_data, 1, mode='eval')
...
@@ -229,10 +229,11 @@ def main():
                   % (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
             if not args.profile:
-                dir_name = os.path.join(args.model_path,
-                                        "epoch_" + str(epoch_id))
-                print("begin to save", dir_name)
-                fluid.io.save_params(exe, dir_name, main_program=train_program)
+                save_path = os.path.join(args.model_path,
+                                         "epoch_" + str(epoch_id),
+                                         "checkpoint")
+                print("begin to save", save_path)
+                fluid.save(train_program, save_path)
                 print("save finished")
             dev_ppl = eval(valid_data)
             print("dev ppl", dev_ppl)
...
@@ -88,7 +88,8 @@ def infer():
     dir_name = args.reload_model
     print("dir name", dir_name)
-    fluid.io.load_params(exe, dir_name)
+    dir_name = os.path.join(dir_name, "checkpoint")
+    fluid.load(main_program, dir_name, exe)
     vocab, tar_id2vocab = get_vocab(args.dataset_prefix)
     infer_output = np.ones((batch_size, 1), dtype='int64') * BOS_ID
...
@@ -255,10 +255,11 @@ def main():
                 best_nll = test_nll
                 best_ppl = test_ppl
                 best_epoch_id = epoch_id
-                dir_name = os.path.join(args.model_path,
-                                        "epoch_" + str(best_epoch_id))
-                print("save model {}".format(dir_name))
-                fluid.io.save_params(exe, dir_name, main_program)
+                save_path = os.path.join(args.model_path,
+                                         "epoch_" + str(best_epoch_id),
+                                         "checkpoint")
+                print("save model {}".format(save_path))
+                fluid.save(main_program, save_path)
             else:
                 steps_not_improved += 1
                 if steps_not_improved == decay_ts:
...
@@ -187,17 +187,17 @@ def do_train(args):
                         end_time - start_time, train_pyreader.queue.size()))
             if steps % args.save_steps == 0:
-                save_path = os.path.join(args.model_save_dir,
-                                         "step_" + str(steps))
+                save_path = os.path.join(args.model_save_dir, "step_" + str(steps),
+                                         "checkpoint")
                 print("\tsaving model as %s" % (save_path))
-                fluid.io.save_persistables(exe, save_path, train_program)
+                fluid.save(train_program, save_path)
             if steps % args.validation_steps == 0:
                 evaluate(exe, test_program, test_pyreader, train_ret)

-    save_path = os.path.join(args.model_save_dir, "step_" + str(steps))
-    fluid.io.save_persistables(exe, save_path, train_program)
+    save_path = os.path.join(args.model_save_dir, "step_" + str(steps),
+                             "checkpoint")
+    fluid.save(train_program, save_path)


 def do_eval(args):
     # init executor
...
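Because every periodic save in the hunk above now ends in a "step_<N>/checkpoint" prefix, resuming a run is just a fluid.load on the newest prefix. A hedged sketch under that assumption; resume_from_latest is a hypothetical helper, not part of this commit:

    import os
    import paddle.fluid as fluid

    def resume_from_latest(model_save_dir, train_program, exe):
        # Hypothetical helper: find the newest step_<N> directory written by
        # the updated do_train loop and restore it into train_program.
        steps = sorted(
            int(name.split("_")[1])
            for name in os.listdir(model_save_dir)
            if name.startswith("step_"))
        if not steps:
            return 0
        ckpt = os.path.join(model_save_dir, "step_%d" % steps[-1], "checkpoint")
        fluid.load(train_program, ckpt, exe)
        return steps[-1]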
@@ -151,8 +151,9 @@ def do_train(args):
             # save checkpoints
             if step % args.save_steps == 0 and step != 0:
                 save_path = os.path.join(args.model_save_dir,
-                                         "step_" + str(step))
-                fluid.io.save_persistables(exe, save_path, train_program)
+                                         "step_" + str(step),
+                                         "checkpoint")
+                fluid.save(train_program, save_path)
             step += 1

     if args.enable_ce:
...
@@ -200,22 +200,13 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
     assert os.path.exists(
         init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path

-    def existed_persitables(var):
-        """
-        If existed presitabels
-        """
-        if not fluid.io.is_persistable(var):
-            return False
-        return os.path.exists(os.path.join(init_checkpoint_path, var.name))
-
-    fluid.io.load_vars(
-        exe,
-        init_checkpoint_path,
-        main_program=main_program,
-        predicate=existed_persitables)
+    try:
+        checkpoint_path = os.path.join(init_checkpoint_path, "checkpoint")
+        fluid.load(main_program, checkpoint_path, exe)
+    except:
+        fluid.load(main_program, init_checkpoint_path, exe)
     print("Load model from {}".format(init_checkpoint_path))


 def init_pretraining_params(exe,
                             pretraining_params_path,
                             main_program,
@@ -224,15 +215,6 @@ def init_pretraining_params(exe,
     assert os.path.exists(pretraining_params_path
                           ), "[%s] cann't be found." % pretraining_params_path

-    def _existed_params(var):
-        if not isinstance(var, fluid.framework.Parameter):
-            return False
-        return os.path.exists(os.path.join(pretraining_params_path, var.name))
-
-    fluid.io.load_vars(
-        exe,
-        pretraining_params_path,
-        main_program=main_program,
-        predicate=_existed_params)
+    fluid.load(main_program, pretraining_params_path, exe)
     print("Load pretraining parameters from {}.".format(
         pretraining_params_path))
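The try/except in init_checkpoint above prefers the new "checkpoint" prefix and falls back to pointing fluid.load at the raw directory (the legacy per-variable layout). An equivalent sketch that probes for the new format explicitly instead of catching every exception, assuming fluid.save's 1.7 behavior of writing a .pdparams file next to the prefix; init_checkpoint_compat is a hypothetical name:

    import os
    import paddle.fluid as fluid

    def init_checkpoint_compat(exe, init_checkpoint_path, main_program):
        # Hypothetical variant: check for the new single-prefix format before
        # falling back to the legacy per-variable directory.
        prefix = os.path.join(init_checkpoint_path, "checkpoint")
        if os.path.exists(prefix + ".pdparams"):
            fluid.load(main_program, prefix, exe)
        else:
            fluid.load(main_program, init_checkpoint_path, exe)
        print("Load model from {}".format(init_checkpoint_path))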
@@ -289,8 +289,9 @@ def main(args):
             if steps % args.save_steps == 0:
                 save_path = os.path.join(args.checkpoints,
-                                         "step_" + str(steps))
-                fluid.io.save_persistables(exe, save_path, train_program)
+                                         "step_" + str(steps),
+                                         "checkpoint")
+                fluid.save(train_program, save_path)
             if steps % args.validation_steps == 0:
                 # evaluate dev set
@@ -301,8 +302,9 @@ def main(args):
                     "dev")
         except fluid.core.EOFException:
-            save_path = os.path.join(args.checkpoints, "step_" + str(steps))
-            fluid.io.save_persistables(exe, save_path, train_program)
+            save_path = os.path.join(args.checkpoints, "step_" + str(steps),
+                                     "checkpoint")
+            fluid.save(train_program, save_path)
             train_reader.reset()
             break
...
@@ -353,8 +353,9 @@ def main(args):
             if steps % args.save_steps == 0:
                 save_path = os.path.join(args.checkpoints,
-                                         "step_" + str(steps))
-                fluid.io.save_persistables(exe, save_path, train_program)
+                                         "step_" + str(steps),
+                                         "checkpoint")
+                fluid.save(train_program, save_path)
             if steps % args.validation_steps == 0:
                 # evaluate dev set
@@ -364,8 +365,9 @@ def main(args):
                     "dev")
         except fluid.core.EOFException:
-            save_path = os.path.join(args.checkpoints, "step_" + str(steps))
-            fluid.io.save_persistables(exe, save_path, train_program)
+            save_path = os.path.join(args.checkpoints, "step_" + str(steps),
+                                     "checkpoint")
+            fluid.save(train_program, save_path)
             train_pyreader.reset()
             break
...
@@ -63,20 +63,11 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
     """
     assert os.path.exists(
         init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path

-    def existed_persitables(var):
-        """
-        If existed presitabels
-        """
-        if not fluid.io.is_persistable(var):
-            return False
-        return os.path.exists(os.path.join(init_checkpoint_path, var.name))
-
-    fluid.io.load_vars(
-        exe,
-        init_checkpoint_path,
-        main_program=main_program,
-        predicate=existed_persitables)
+    try:
+        checkpoint_path = os.path.join(init_checkpoint_path, "checkpoint")
+        fluid.load(main_program, checkpoint_path, exe)
+    except:
+        fluid.load(main_program, init_checkpoint_path, exe)
     print("Load model from {}".format(init_checkpoint_path))
@@ -144,15 +135,6 @@ def init_pretraining_params(exe,
     assert os.path.exists(pretraining_params_path
                           ), "[%s] cann't be found." % pretraining_params_path

-    def _existed_params(var):
-        if not isinstance(var, fluid.framework.Parameter):
-            return False
-        return os.path.exists(os.path.join(pretraining_params_path, var.name))
-
-    fluid.io.load_vars(
-        exe,
-        pretraining_params_path,
-        main_program=main_program,
-        predicate=_existed_params)
+    fluid.load(main_program, pretraining_params_path, exe)
     print("Load pretraining parameters from {}.".format(
         pretraining_params_path))