未验证 提交 2746e74b 编写于 作者: H hutuxian 提交者: GitHub

PaddleRec/gnn: use save/load API (#4626)

上级 406187e0
......@@ -62,10 +62,9 @@ def infer(args):
for epoch_num in range(args.start_index, args.last_index + 1):
model_path = os.path.join(args.model_path, "epoch_" + str(epoch_num))
try:
if not os.path.exists(model_path):
if not os.path.exists(model_path + ".pdmodel"):
raise ValueError()
fluid.io.load_persistables(executor=exe, dirname=model_path,
main_program=infer_program)
fluid.io.load(infer_program, model_path+".pdmodel", exe)
loss_sum = 0.0
acc_sum = 0.0
......
......@@ -140,7 +140,7 @@ def train():
logger.info("epoch loss: %.4lf" % (np.mean(epoch_sum)))
save_dir = os.path.join(args.model_path, "epoch_" + str(i))
fetch_vars = [loss, acc]
fluid.io.save_inference_model(save_dir, feed_list, fetch_vars, exe)
fluid.save(fluid.default_main_program(), model_path=save_dir)
logger.info("model saved in " + save_dir)
# only for ce
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册