From b93b501b72d1d911bcbb7a493d7864d2d3244f37 Mon Sep 17 00:00:00 2001 From: xujiaqi01 <173596896@qq.com> Date: Thu, 20 Feb 2020 16:37:57 +0800 Subject: [PATCH] upgrade save and load interface (#4311) (#4319) * upgrade dcn and xdeepfm, change from all old save/load api to fluid.save and fluid.load * test=develop --- PaddleRec/ctr/dcn/infer.py | 9 ++++----- PaddleRec/ctr/dcn/local_train.py | 7 ++----- PaddleRec/ctr/xdeepfm/infer.py | 9 ++++----- PaddleRec/ctr/xdeepfm/local_train.py | 7 ++----- 4 files changed, 12 insertions(+), 20 deletions(-) diff --git a/PaddleRec/ctr/dcn/infer.py b/PaddleRec/ctr/dcn/infer.py index 260bbe90..83e49e0b 100644 --- a/PaddleRec/ctr/dcn/infer.py +++ b/PaddleRec/ctr/dcn/infer.py @@ -45,7 +45,7 @@ def infer(): startup_program = fluid.framework.Program() test_program = fluid.framework.Program() cur_model_path = os.path.join(args.model_output_dir, - 'epoch_' + args.test_epoch) + 'epoch_' + args.test_epoch, "checkpoint") with fluid.scope_guard(inference_scope): with fluid.framework.program_guard(test_program, startup_program): @@ -62,10 +62,9 @@ def infer(): exe = fluid.Executor(place) feeder = fluid.DataFeeder( feed_list=dcn_model.data_list, place=place) - fluid.io.load_persistables( - executor=exe, - dirname=cur_model_path, - main_program=fluid.default_main_program()) + + exe.run(startup_program) + fluid.io.load(fluid.default_main_program(), cur_model_path) for var in dcn_model.auc_states: # reset auc states set_zero(var.name, scope=inference_scope, place=place) diff --git a/PaddleRec/ctr/dcn/local_train.py b/PaddleRec/ctr/dcn/local_train.py index d01d702f..fd807e45 100644 --- a/PaddleRec/ctr/dcn/local_train.py +++ b/PaddleRec/ctr/dcn/local_train.py @@ -80,13 +80,10 @@ def train(args): debug=False, print_period=args.print_steps) model_dir = os.path.join(args.model_output_dir, - 'epoch_' + str(epoch_id + 1)) + 'epoch_' + str(epoch_id + 1), "checkpoint") sys.stderr.write('epoch%d is finished and takes %f s\n' % ( (epoch_id + 1), time.time() - start)) - fluid.io.save_persistables( - executor=exe, - dirname=model_dir, - main_program=fluid.default_main_program()) + fluid.save(fluid.default_main_program(), model_dir) if __name__ == '__main__': diff --git a/PaddleRec/ctr/xdeepfm/infer.py b/PaddleRec/ctr/xdeepfm/infer.py index 489d6337..1c6277ad 100644 --- a/PaddleRec/ctr/xdeepfm/infer.py +++ b/PaddleRec/ctr/xdeepfm/infer.py @@ -36,7 +36,7 @@ def infer(): startup_program = fluid.framework.Program() test_program = fluid.framework.Program() cur_model_path = os.path.join(args.model_output_dir, - 'epoch_' + args.test_epoch) + 'epoch_' + args.test_epoch, "checkpoint") with fluid.scope_guard(inference_scope): with fluid.framework.program_guard(test_program, startup_program): @@ -48,10 +48,9 @@ def infer(): exe = fluid.Executor(place) feeder = fluid.DataFeeder(feed_list=data_list, place=place) - fluid.io.load_persistables( - executor=exe, - dirname=cur_model_path, - main_program=fluid.default_main_program()) + + exe.run(startup_program) + fluid.io.load(fluid.default_main_program(), cur_model_path) for var in auc_states: # reset auc states set_zero(var.name, scope=inference_scope, place=place) diff --git a/PaddleRec/ctr/xdeepfm/local_train.py b/PaddleRec/ctr/xdeepfm/local_train.py index c5fc8a2e..2bb7e1de 100644 --- a/PaddleRec/ctr/xdeepfm/local_train.py +++ b/PaddleRec/ctr/xdeepfm/local_train.py @@ -55,13 +55,10 @@ def train(): debug=False, print_period=args.print_steps) model_dir = os.path.join(args.model_output_dir, - 'epoch_' + str(epoch_id + 1)) + 'epoch_' + str(epoch_id + 1), "checkpoint") sys.stderr.write('epoch%d is finished and takes %f s\n' % ( (epoch_id + 1), time.time() - start)) - fluid.io.save_persistables( - executor=exe, - dirname=model_dir, - main_program=fluid.default_main_program()) + fluid.io.save_persistables(fluid.default_main_program(), model_dir) if __name__ == '__main__': -- GitLab