未验证 提交 018a33dd 编写于 作者: Y Yibing Liu 提交者: GitHub

Use save & load api for bert (#4320)

上级 b93b501b
...@@ -158,7 +158,7 @@ def optimization(loss, ...@@ -158,7 +158,7 @@ def optimization(loss,
else: else:
if weight_decay > 0: if weight_decay > 0:
for param in train_program.global_block().all_parameters(): for param in train_program.all_parameters():
param_list[param.name] = param * 1.0 param_list[param.name] = param * 1.0
param_list[param.name].stop_gradient = True param_list[param.name].stop_gradient = True
......
...@@ -392,7 +392,7 @@ def main(args): ...@@ -392,7 +392,7 @@ def main(args):
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints, save_path = os.path.join(args.checkpoints,
"step_" + str(steps)) "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(program=train_program, model_path=save_path)
if steps % args.validation_steps == 0: if steps % args.validation_steps == 0:
print("Average throughtput: %s" % (np.average(throughput))) print("Average throughtput: %s" % (np.average(throughput)))
...@@ -409,7 +409,7 @@ def main(args): ...@@ -409,7 +409,7 @@ def main(args):
"test") "test")
except fluid.core.EOFException: except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps)) save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(program=train_program, model_path=save_path)
train_data_loader.reset() train_data_loader.reset()
break break
if args.enable_ce: if args.enable_ce:
......
...@@ -398,11 +398,11 @@ def train(args): ...@@ -398,11 +398,11 @@ def train(args):
if steps % args.save_steps == 0 or steps == max_train_steps: if steps % args.save_steps == 0 or steps == max_train_steps:
save_path = os.path.join(args.checkpoints, save_path = os.path.join(args.checkpoints,
"step_" + str(steps)) "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(program=train_program, model_path=save_path)
except fluid.core.EOFException: except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, save_path = os.path.join(args.checkpoints,
"step_" + str(steps) + "_final") "step_" + str(steps) + "_final")
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(program=train_program, model_path=save_path)
train_data_loader.reset() train_data_loader.reset()
break break
......
...@@ -412,7 +412,7 @@ def train(args): ...@@ -412,7 +412,7 @@ def train(args):
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints, "step_" + str(steps)) save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(program=train_program, model_path=save_path)
if args.validation_set_dir and steps % args.validation_steps == 0: if args.validation_set_dir and steps % args.validation_steps == 0:
vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict( vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict(
......
...@@ -25,7 +25,7 @@ import paddle.fluid as fluid ...@@ -25,7 +25,7 @@ import paddle.fluid as fluid
def cast_fp32_to_fp16(exe, main_program): def cast_fp32_to_fp16(exe, main_program):
print("Cast parameters to float16 data format.") print("Cast parameters to float16 data format.")
for param in main_program.global_block().all_parameters(): for param in main_program.all_parameters():
if not param.name.endswith(".master"): if not param.name.endswith(".master"):
param_t = fluid.global_scope().find_var(param.name).get_tensor() param_t = fluid.global_scope().find_var(param.name).get_tensor()
data = np.array(param_t) data = np.array(param_t)
...@@ -38,21 +38,9 @@ def cast_fp32_to_fp16(exe, main_program): ...@@ -38,21 +38,9 @@ def cast_fp32_to_fp16(exe, main_program):
def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False): def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
assert os.path.exists( fluid.load(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path program=main_program, model_path=init_checkpoint_path, executor=exe)
def existed_persitables(var):
if not fluid.io.is_persistable(var):
return False
if os.path.exists(os.path.join(init_checkpoint_path, var.name)):
print("INIT {}".format(var.name))
return True
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path)) print("Load model from {}".format(init_checkpoint_path))
if use_fp16: if use_fp16:
...@@ -63,24 +51,8 @@ def init_pretraining_params(exe, ...@@ -63,24 +51,8 @@ def init_pretraining_params(exe,
pretraining_params_path, pretraining_params_path,
main_program, main_program,
use_fp16=False): use_fp16=False):
assert os.path.exists(pretraining_params_path fluid.load(
), "[%s] cann't be found." % pretraining_params_path program=main_program, model_path=pretraining_params_path, executor=exe)
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
if os.path.exists(os.path.join(pretraining_params_path, var.name)):
print("INIT {}".format(var.name))
return True
else:
print("SKIP {}".format(var.name))
return False
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
print("Load pretraining parameters from {}.".format( print("Load pretraining parameters from {}.".format(
pretraining_params_path)) pretraining_params_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册