From 68c7b0468755f94274ee8332c998ecc055b41f44 Mon Sep 17 00:00:00 2001 From: yaoxuefeng Date: Mon, 17 Feb 2020 18:17:46 +0800 Subject: [PATCH] update1.7 save/load (#4298) --- PaddleRec/ctr/deepfm/cluster_train.py | 3 +-- PaddleRec/ctr/deepfm/infer.py | 7 ++----- PaddleRec/ctr/deepfm/local_train.py | 6 ++---- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/PaddleRec/ctr/deepfm/cluster_train.py b/PaddleRec/ctr/deepfm/cluster_train.py index c0509d46..da565172 100644 --- a/PaddleRec/ctr/deepfm/cluster_train.py +++ b/PaddleRec/ctr/deepfm/cluster_train.py @@ -162,8 +162,7 @@ def train(): (epoch_id + 1), time.time() - start)) if args.trainer_id == 0: # only trainer 0 save model print("save model in {}".format(model_dir)) - fluid.io.save_persistables( - executor=exe, dirname=model_dir, main_program=main_program) + fluid.save(main_program, model_dir) print("train time cost {:.4f}".format(time.time() - start_time)) print("finish training") diff --git a/PaddleRec/ctr/deepfm/infer.py b/PaddleRec/ctr/deepfm/infer.py index 2b7e29a7..9ff58af7 100644 --- a/PaddleRec/ctr/deepfm/infer.py +++ b/PaddleRec/ctr/deepfm/infer.py @@ -46,11 +46,8 @@ def infer(): exe = fluid.Executor(place) feeder = fluid.DataFeeder(feed_list=data_list, place=place) - fluid.io.load_persistables( - executor=exe, - dirname=cur_model_path, - main_program=fluid.default_main_program()) - + main_program = fluid.default_main_program() + fluid.load(main_program, cur_model_path, exe) for var in auc_states: # reset auc states set_zero(var.name, scope=inference_scope, place=place) diff --git a/PaddleRec/ctr/deepfm/local_train.py b/PaddleRec/ctr/deepfm/local_train.py index 001f625e..a0894c9b 100644 --- a/PaddleRec/ctr/deepfm/local_train.py +++ b/PaddleRec/ctr/deepfm/local_train.py @@ -61,10 +61,8 @@ def train(): 'epoch_' + str(epoch_id + 1)) sys.stderr.write('epoch%d is finished and takes %f s\n' % ( (epoch_id + 1), time.time() - start)) - fluid.io.save_persistables( - executor=exe, - dirname=model_dir, - main_program=fluid.default_main_program()) + main_program = fluid.default_main_program() + fluid.io.save(main_program, model_dir) if __name__ == '__main__': -- GitLab