未验证 提交 4d28bb18 编写于 作者: X Xing Wu 提交者: GitHub

update save/load for v1.7 (#4332)

* update save/load for v1.7

* update load api

* update load api
上级 d0937f0b
......@@ -127,7 +127,8 @@ def infer():
dir_name = args.reload_model
print("dir name", dir_name)
fluid.io.load_params(exe, dir_name)
dir_name = os.path.join(dir_name, "checkpoint")
fluid.load(main_program, dir_name, exe)
train_data_iter = reader.get_data_iter(infer_data, 1, mode='eval')
......
......@@ -229,10 +229,11 @@ def main():
% (epoch_id, epoch_time, sum(batch_times) / len(batch_times)))
if not args.profile:
dir_name = os.path.join(args.model_path,
"epoch_" + str(epoch_id))
print("begin to save", dir_name)
fluid.io.save_params(exe, dir_name, main_program=train_program)
save_path = os.path.join(args.model_path,
"epoch_" + str(epoch_id),
"checkpoint")
print("begin to save", save_path)
fluid.save(train_program, save_path)
print("save finished")
dev_ppl = eval(valid_data)
print("dev ppl", dev_ppl)
......
......@@ -88,7 +88,8 @@ def infer():
dir_name = args.reload_model
print("dir name", dir_name)
fluid.io.load_params(exe, dir_name)
dir_name = os.path.join(dir_name, "checkpoint")
fluid.load(main_program, dir_name, exe)
vocab, tar_id2vocab = get_vocab(args.dataset_prefix)
infer_output = np.ones((batch_size, 1), dtype='int64') * BOS_ID
......
......@@ -255,10 +255,11 @@ def main():
best_nll = test_nll
best_ppl = test_ppl
best_epoch_id = epoch_id
dir_name = os.path.join(args.model_path,
"epoch_" + str(best_epoch_id))
print("save model {}".format(dir_name))
fluid.io.save_params(exe, dir_name, main_program)
save_path = os.path.join(args.model_path,
"epoch_" + str(best_epoch_id),
"checkpoint")
print("save model {}".format(save_path))
fluid.save(main_program, save_path)
else:
steps_not_improved += 1
if steps_not_improved == decay_ts:
......
......@@ -187,17 +187,17 @@ def do_train(args):
end_time - start_time, train_pyreader.queue.size()))
if steps % args.save_steps == 0:
save_path = os.path.join(args.model_save_dir,
"step_" + str(steps))
save_path = os.path.join(args.model_save_dir, "step_" + str(steps),
"checkpoint")
print("\tsaving model as %s" % (save_path))
fluid.io.save_persistables(exe, save_path, train_program)
fluid.save(train_program, save_path)
if steps % args.validation_steps == 0:
evaluate(exe, test_program, test_pyreader, train_ret)
save_path = os.path.join(args.model_save_dir, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
save_path = os.path.join(args.model_save_dir, "step_" + str(steps),
"checkpoint")
fluid.save(train_program, save_path)
def do_eval(args):
# init executor
......
......@@ -151,8 +151,9 @@ def do_train(args):
# save checkpoints
if step % args.save_steps == 0 and step != 0:
save_path = os.path.join(args.model_save_dir,
"step_" + str(step))
fluid.io.save_persistables(exe, save_path, train_program)
"step_" + str(step),
"checkpoint")
fluid.save(train_program, save_path)
step += 1
if args.enable_ce:
......
......@@ -200,22 +200,13 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var):
"""
If existed presitabels
"""
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
try:
checkpoint_path = os.path.join(init_checkpoint_path, "checkpoint")
fluid.load(main_program, checkpoint_path, exe)
except:
fluid.load(main_program, init_checkpoint_path, exe)
print("Load model from {}".format(init_checkpoint_path))
def init_pretraining_params(exe,
pretraining_params_path,
main_program,
......@@ -224,15 +215,6 @@ def init_pretraining_params(exe,
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def _existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=_existed_params)
fluid.load(main_program, pretraining_params_path, exe)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
......@@ -289,8 +289,9 @@ def main(args):
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
"step_" + str(steps),
"checkpoint")
fluid.save(train_program, save_path)
if steps % args.validation_steps == 0:
# evaluate dev set
......@@ -301,8 +302,9 @@ def main(args):
"dev")
except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
save_path = os.path.join(args.checkpoints, "step_" + str(steps),
"checkpoint")
fluid.save(train_program, save_path)
train_reader.reset()
break
......
......@@ -353,8 +353,9 @@ def main(args):
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
"step_" + str(steps),
"checkpoint")
fluid.save(train_program, save_path)
if steps % args.validation_steps == 0:
# evaluate dev set
......@@ -364,8 +365,9 @@ def main(args):
"dev")
except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
save_path = os.path.join(args.checkpoints, "step_" + str(steps),
"checkpoint")
fluid.save(train_program, save_path)
train_pyreader.reset()
break
......
......@@ -63,20 +63,11 @@ def init_checkpoint(exe, init_checkpoint_path, main_program):
"""
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var):
"""
If existed presitabels
"""
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
try:
checkpoint_path = os.path.join(init_checkpoint_path, "checkpoint")
fluid.load(main_program, checkpoint_path, exe)
except:
fluid.load(main_program, init_checkpoint_path, exe)
print("Load model from {}".format(init_checkpoint_path))
......@@ -144,15 +135,6 @@ def init_pretraining_params(exe,
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def _existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=_existed_params)
fluid.load(main_program, pretraining_params_path, exe)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册