From 2746e74b71ceb7ac93679a39116e5241a9f0fa64 Mon Sep 17 00:00:00 2001 From: hutuxian Date: Thu, 14 May 2020 23:44:57 +0800 Subject: [PATCH] PaddleRec/gnn: use save/load API (#4626) --- PaddleRec/gnn/infer.py | 5 ++--- PaddleRec/gnn/train.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/PaddleRec/gnn/infer.py b/PaddleRec/gnn/infer.py index f8d1f111..787ee6f5 100644 --- a/PaddleRec/gnn/infer.py +++ b/PaddleRec/gnn/infer.py @@ -62,10 +62,9 @@ def infer(args): for epoch_num in range(args.start_index, args.last_index + 1): model_path = os.path.join(args.model_path, "epoch_" + str(epoch_num)) try: - if not os.path.exists(model_path): + if not os.path.exists(model_path + ".pdmodel"): raise ValueError() - fluid.io.load_persistables(executor=exe, dirname=model_path, - main_program=infer_program) + fluid.io.load(infer_program, model_path+".pdmodel", exe) loss_sum = 0.0 acc_sum = 0.0 diff --git a/PaddleRec/gnn/train.py b/PaddleRec/gnn/train.py index a4a1898e..f96b66de 100644 --- a/PaddleRec/gnn/train.py +++ b/PaddleRec/gnn/train.py @@ -140,7 +140,7 @@ def train(): logger.info("epoch loss: %.4lf" % (np.mean(epoch_sum))) save_dir = os.path.join(args.model_path, "epoch_" + str(i)) fetch_vars = [loss, acc] - fluid.io.save_inference_model(save_dir, feed_list, fetch_vars, exe) + fluid.save(fluid.default_main_program(), model_path=save_dir) logger.info("model saved in " + save_dir) # only for ce -- GitLab