未验证 提交 b93b501b 编写于 作者: X xujiaqi01 提交者: GitHub

upgrade save and load interface (#4311) (#4319)

* upgrade dcn and xdeepfm, change from all old save/load api to fluid.save and fluid.load
* test=develop
上级 b9bd9ed0
......@@ -45,7 +45,7 @@ def infer():
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
cur_model_path = os.path.join(args.model_output_dir,
'epoch_' + args.test_epoch)
'epoch_' + args.test_epoch, "checkpoint")
with fluid.scope_guard(inference_scope):
with fluid.framework.program_guard(test_program, startup_program):
......@@ -62,10 +62,9 @@ def infer():
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(
feed_list=dcn_model.data_list, place=place)
fluid.io.load_persistables(
executor=exe,
dirname=cur_model_path,
main_program=fluid.default_main_program())
exe.run(startup_program)
fluid.io.load(fluid.default_main_program(), cur_model_path)
for var in dcn_model.auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place)
......
......@@ -80,13 +80,10 @@ def train(args):
debug=False,
print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1))
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
fluid.io.save_persistables(
executor=exe,
dirname=model_dir,
main_program=fluid.default_main_program())
fluid.save(fluid.default_main_program(), model_dir)
if __name__ == '__main__':
......
......@@ -36,7 +36,7 @@ def infer():
startup_program = fluid.framework.Program()
test_program = fluid.framework.Program()
cur_model_path = os.path.join(args.model_output_dir,
'epoch_' + args.test_epoch)
'epoch_' + args.test_epoch, "checkpoint")
with fluid.scope_guard(inference_scope):
with fluid.framework.program_guard(test_program, startup_program):
......@@ -48,10 +48,9 @@ def infer():
exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
fluid.io.load_persistables(
executor=exe,
dirname=cur_model_path,
main_program=fluid.default_main_program())
exe.run(startup_program)
fluid.io.load(fluid.default_main_program(), cur_model_path)
for var in auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place)
......
......@@ -55,13 +55,10 @@ def train():
debug=False,
print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1))
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
fluid.io.save_persistables(
executor=exe,
dirname=model_dir,
main_program=fluid.default_main_program())
fluid.io.save_persistables(fluid.default_main_program(), model_dir)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册