未验证 提交 b93b501b 编写于 作者: X xujiaqi01 提交者: GitHub

upgrade save and load interface (#4311) (#4319)

* upgrade dcn and xdeepfm, change from all old save/load api to fluid.save and fluid.load
* test=develop
上级 b9bd9ed0
...@@ -45,7 +45,7 @@ def infer(): ...@@ -45,7 +45,7 @@ def infer():
startup_program = fluid.framework.Program() startup_program = fluid.framework.Program()
test_program = fluid.framework.Program() test_program = fluid.framework.Program()
cur_model_path = os.path.join(args.model_output_dir, cur_model_path = os.path.join(args.model_output_dir,
'epoch_' + args.test_epoch) 'epoch_' + args.test_epoch, "checkpoint")
with fluid.scope_guard(inference_scope): with fluid.scope_guard(inference_scope):
with fluid.framework.program_guard(test_program, startup_program): with fluid.framework.program_guard(test_program, startup_program):
...@@ -62,10 +62,9 @@ def infer(): ...@@ -62,10 +62,9 @@ def infer():
exe = fluid.Executor(place) exe = fluid.Executor(place)
feeder = fluid.DataFeeder( feeder = fluid.DataFeeder(
feed_list=dcn_model.data_list, place=place) feed_list=dcn_model.data_list, place=place)
fluid.io.load_persistables(
executor=exe, exe.run(startup_program)
dirname=cur_model_path, fluid.io.load(fluid.default_main_program(), cur_model_path)
main_program=fluid.default_main_program())
for var in dcn_model.auc_states: # reset auc states for var in dcn_model.auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place) set_zero(var.name, scope=inference_scope, place=place)
......
...@@ -80,13 +80,10 @@ def train(args): ...@@ -80,13 +80,10 @@ def train(args):
debug=False, debug=False,
print_period=args.print_steps) print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir, model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1)) 'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % ( sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start)) (epoch_id + 1), time.time() - start))
fluid.io.save_persistables( fluid.save(fluid.default_main_program(), model_dir)
executor=exe,
dirname=model_dir,
main_program=fluid.default_main_program())
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -36,7 +36,7 @@ def infer(): ...@@ -36,7 +36,7 @@ def infer():
startup_program = fluid.framework.Program() startup_program = fluid.framework.Program()
test_program = fluid.framework.Program() test_program = fluid.framework.Program()
cur_model_path = os.path.join(args.model_output_dir, cur_model_path = os.path.join(args.model_output_dir,
'epoch_' + args.test_epoch) 'epoch_' + args.test_epoch, "checkpoint")
with fluid.scope_guard(inference_scope): with fluid.scope_guard(inference_scope):
with fluid.framework.program_guard(test_program, startup_program): with fluid.framework.program_guard(test_program, startup_program):
...@@ -48,10 +48,9 @@ def infer(): ...@@ -48,10 +48,9 @@ def infer():
exe = fluid.Executor(place) exe = fluid.Executor(place)
feeder = fluid.DataFeeder(feed_list=data_list, place=place) feeder = fluid.DataFeeder(feed_list=data_list, place=place)
fluid.io.load_persistables(
executor=exe, exe.run(startup_program)
dirname=cur_model_path, fluid.io.load(fluid.default_main_program(), cur_model_path)
main_program=fluid.default_main_program())
for var in auc_states: # reset auc states for var in auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place) set_zero(var.name, scope=inference_scope, place=place)
......
...@@ -55,13 +55,10 @@ def train(): ...@@ -55,13 +55,10 @@ def train():
debug=False, debug=False,
print_period=args.print_steps) print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir, model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1)) 'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % ( sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start)) (epoch_id + 1), time.time() - start))
fluid.io.save_persistables( fluid.io.save_persistables(fluid.default_main_program(), model_dir)
executor=exe,
dirname=model_dir,
main_program=fluid.default_main_program())
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册