提交 9f7258da 编写于 作者: P phlrain

new save load interface; test=develop

上级 54f64c66
......@@ -183,7 +183,8 @@ def convert(args):
param]).get_tensor().set(value, place)
print(param, ' --> ', tf_fluid_param_name_map[param], ' ', value.shape)
fluid.io.save_params(exe, args.fluid_params_dir, main_program=program)
save_path = os.join( args.fluid_params_dir, "checkpoint")
fluid.save( program, save_path)
if __name__ == '__main__':
......
......@@ -392,7 +392,7 @@ def main(args):
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
fluid.save( train_program, save_path )
if steps % args.validation_steps == 0:
print("Average throughtput: %s" % (np.average(throughput)))
......@@ -409,7 +409,7 @@ def main(args):
"test")
except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
fluid.save( train_program, save_path )
train_data_loader.reset()
break
if args.enable_ce:
......
......@@ -398,11 +398,11 @@ def train(args):
if steps % args.save_steps == 0 or steps == max_train_steps:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
fluid.save( train_program, save_path )
except fluid.core.EOFException:
save_path = os.path.join(args.checkpoints,
"step_" + str(steps) + "_final")
fluid.io.save_persistables(exe, save_path, train_program)
fluid.save( train_program, save_path )
train_data_loader.reset()
break
......
......@@ -38,7 +38,7 @@ from optimization import optimization
from utils.args import ArgumentGroup, print_arguments, check_cuda
from utils.init import init_checkpoint, init_pretraining_params
# yapf: disable
yapf: disable
parser = argparse.ArgumentParser(__doc__)
model_g = ArgumentGroup(parser, "model", "model configuration and paths.")
model_g.add_arg("bert_config_path", str, "./config/bert_config.json", "Path to the json file for bert model config.")
......@@ -412,7 +412,8 @@ def train(args):
if steps % args.save_steps == 0:
save_path = os.path.join(args.checkpoints, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program)
#fluid.io.save_persistables(exe, save_path, train_program)
fluid.save( train_program, save_path )
if args.validation_set_dir and steps % args.validation_steps == 0:
vali_cost, vali_lm_cost, vali_acc, vali_steps, vali_speed = predict(
......
......@@ -47,12 +47,8 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, use_fp16=False):
if os.path.exists(os.path.join(init_checkpoint_path, var.name)):
print("INIT {}".format(var.name))
return True
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
fluid.load( main_program, init_checkpoint_path, exe)
print("Load model from {}".format(init_checkpoint_path))
if use_fp16:
......@@ -63,24 +59,12 @@ def init_pretraining_params(exe,
pretraining_params_path,
main_program,
use_fp16=False):
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
assert os.path.exists(pretraining_params_path + ".params"
), "[%s] cann't be found." % (pretraining_params_path + ".params" )
def existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
if os.path.exists(os.path.join(pretraining_params_path, var.name)):
print("INIT {}".format(var.name))
return True
else:
print("SKIP {}".format(var.name))
return False
program_state = fluid.load_program_state( pretraining_params_path )
fluid.set_program_state( main_program, program_state)
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=existed_params)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册