未验证 提交 68c7b046 编写于 作者: Y yaoxuefeng 提交者: GitHub

update1.7 save/load (#4298)

上级 f1e2c268
......@@ -162,8 +162,7 @@ def train():
(epoch_id + 1), time.time() - start))
if args.trainer_id == 0: # only trainer 0 save model
print("save model in {}".format(model_dir))
fluid.io.save_persistables(
executor=exe, dirname=model_dir, main_program=main_program)
fluid.save(main_program, model_dir)
print("train time cost {:.4f}".format(time.time() - start_time))
print("finish training")
......
......@@ -46,11 +46,8 @@ def infer():
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
fluid.io.load_persistables(
executor=exe,
dirname=cur_model_path,
main_program=fluid.default_main_program())
main_program = fluid.default_main_program()
fluid.load(main_program, cur_model_path, exe)
for var in auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place)
......
......@@ -61,10 +61,8 @@ def train():
'epoch_' + str(epoch_id + 1))
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
fluid.io.save_persistables(
executor=exe,
dirname=model_dir,
main_program=fluid.default_main_program())
main_program = fluid.default_main_program()
fluid.io.save(main_program, model_dir)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册