未验证 提交 4d9b5e57 编写于 作者: X xujiaqi01 提交者: GitHub

fix save error of xdeepfm (#4507)

* fix save error of xdeepfm
* test=develop
上级 f2c1ecd9
......@@ -169,13 +169,12 @@ def train():
debug=False,
print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1))
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
if args.trainer_id == 0: # only trainer 0 save model
print("save model in {}".format(model_dir))
fluid.io.save_persistables(
executor=exe, dirname=model_dir, main_program=main_program)
fluid.save(main_program, model_dir)
print("train time cost {:.4f}".format(time.time() - start_time))
print("finish training")
......
......@@ -64,7 +64,7 @@ def infer():
feed_list=dcn_model.data_list, place=place)
exe.run(startup_program)
fluid.io.load(fluid.default_main_program(), cur_model_path)
fluid.load(fluid.default_main_program(), cur_model_path)
for var in dcn_model.auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place)
......
......@@ -162,13 +162,12 @@ def train():
debug=False,
print_period=args.print_steps)
model_dir = os.path.join(args.model_output_dir,
'epoch_' + str(epoch_id + 1))
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
if args.trainer_id == 0: # only trainer 0 save model
print("save model in {}".format(model_dir))
fluid.io.save_persistables(
executor=exe, dirname=model_dir, main_program=main_program)
fluid.save(main_program, model_dir)
print("train time cost {:.4f}".format(time.time() - start_time))
print("finish training")
......
......@@ -50,7 +50,7 @@ def infer():
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
exe.run(startup_program)
fluid.io.load(fluid.default_main_program(), cur_model_path)
fluid.load(fluid.default_main_program(), cur_model_path)
for var in auc_states: # reset auc states
set_zero(var.name, scope=inference_scope, place=place)
......
......@@ -58,7 +58,7 @@ def train():
'epoch_' + str(epoch_id + 1), "checkpoint")
sys.stderr.write('epoch%d is finished and takes %f s\n' % (
(epoch_id + 1), time.time() - start))
fluid.io.save_persistables(fluid.default_main_program(), model_dir)
fluid.save(fluid.default_main_program(), model_dir)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册