From 7a3e0c7d77eba5ee165ad994de43c495f8a3b901 Mon Sep 17 00:00:00 2001 From: hong <43953930+phlrain@users.noreply.github.com> Date: Thu, 10 Oct 2019 20:27:30 +0800 Subject: [PATCH] use new save load in all dygraph models (#3495) * use new save load in all dygraph models; test=develop * change load_dict to set_dict; test=develop --- dygraph/cycle_gan/infer.py | 4 ++-- dygraph/cycle_gan/test.py | 4 ++-- dygraph/cycle_gan/train.py | 2 +- dygraph/mnist/train.py | 6 +++--- dygraph/reinforcement_learning/actor_critic.py | 2 +- dygraph/reinforcement_learning/reinforce.py | 2 +- dygraph/reinforcement_learning/test_actor_critic_load.py | 4 ++-- dygraph/reinforcement_learning/test_reinforce_load.py | 4 ++-- dygraph/resnet/train.py | 2 +- dygraph/sentiment/main.py | 6 +++--- 10 files changed, 18 insertions(+), 18 deletions(-) diff --git a/dygraph/cycle_gan/infer.py b/dygraph/cycle_gan/infer.py index b0e2d45d..de70585f 100644 --- a/dygraph/cycle_gan/infer.py +++ b/dygraph/cycle_gan/infer.py @@ -52,8 +52,8 @@ def infer(): os.makedirs(out_path) cycle_gan = Cycle_Gan("cycle_gan") save_dir = args.init_model - restore, _ = fluid.dygraph.load_persistables(save_dir) - cycle_gan.load_dict(restore) + restore, _ = fluid.load_dygraph(save_dir) + cycle_gan.set_dict(restore) cycle_gan.eval() for file in glob.glob(args.input): print ("read %s" % file) diff --git a/dygraph/cycle_gan/test.py b/dygraph/cycle_gan/test.py index 19a2c7fb..ba0b03ba 100644 --- a/dygraph/cycle_gan/test.py +++ b/dygraph/cycle_gan/test.py @@ -52,8 +52,8 @@ def test(): os.makedirs(out_path) cycle_gan = Cycle_Gan("cycle_gan") save_dir = args.init_model + str(epoch) - restore, _ = fluid.dygraph.load_persistables(save_dir) - cycle_gan.load_dict(restore) + restore, _ = fluid.load_dygraph(save_dir) + cycle_gan.set_dict(restore) cycle_gan.eval() for data_A , data_B in zip(A_test_reader(), B_test_reader()): A_name = data_A[1] diff --git a/dygraph/cycle_gan/train.py b/dygraph/cycle_gan/train.py index 59dc5c0c..147b08db 100644 
--- a/dygraph/cycle_gan/train.py +++ b/dygraph/cycle_gan/train.py @@ -204,7 +204,7 @@ def train(args): break if args.save_checkpoints: - fluid.dygraph.save_persistables( + fluid.save_dygraph( cycle_gan.state_dict(), args.output + "/checkpoints/{}".format(epoch)) diff --git a/dygraph/mnist/train.py b/dygraph/mnist/train.py index 5bd5e0d8..b067c94c 100644 --- a/dygraph/mnist/train.py +++ b/dygraph/mnist/train.py @@ -150,8 +150,8 @@ def inference_mnist(): with fluid.dygraph.guard(place): mnist_infer = MNIST("mnist") # load checkpoint - model_dict, _ = fluid.dygraph.load_persistables("save_dir") - mnist_infer.load_dict(model_dict) + model_dict, _ = fluid.load_dygraph("save_temp") + mnist_infer.set_dict(model_dict) print("checkpoint loaded") # start evaluate mode @@ -245,7 +245,7 @@ def train_mnist(args): args.use_data_parallel and fluid.dygraph.parallel.Env().local_rank == 0) if save_parameters: - fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir") + fluid.save_dygraph(mnist.state_dict(), "save_temp") print("checkpoint saved") inference_mnist() diff --git a/dygraph/reinforcement_learning/actor_critic.py b/dygraph/reinforcement_learning/actor_critic.py index c94b0a74..f68a53f8 100644 --- a/dygraph/reinforcement_learning/actor_critic.py +++ b/dygraph/reinforcement_learning/actor_critic.py @@ -200,5 +200,5 @@ with fluid.dygraph.guard(): print("Solved! Running reward is now {} and " "the last episode runs to {} time steps!".format( running_reward, t)) - fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir) + fluid.save_dygraph(policy.state_dict(), args.save_dir) break diff --git a/dygraph/reinforcement_learning/reinforce.py b/dygraph/reinforcement_learning/reinforce.py index 84f9ac02..2a23b345 100644 --- a/dygraph/reinforcement_learning/reinforce.py +++ b/dygraph/reinforcement_learning/reinforce.py @@ -186,5 +186,5 @@ with fluid.dygraph.guard(): print("Solved! 
Running reward is now {} and " "the last episode runs to {} time steps!".format( running_reward, t)) - fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir) + fluid.save_dygraph(policy.state_dict(), args.save_dir) break diff --git a/dygraph/reinforcement_learning/test_actor_critic_load.py b/dygraph/reinforcement_learning/test_actor_critic_load.py index 3f157aa9..2ddbfd8c 100644 --- a/dygraph/reinforcement_learning/test_actor_critic_load.py +++ b/dygraph/reinforcement_learning/test_actor_critic_load.py @@ -174,8 +174,8 @@ with fluid.dygraph.guard(): return returns running_reward = 10 - model_dict, _ = fluid.dygraph.load_persistables(args.save_dir) - policy.load_dict(model_dict) + model_dict, _ = fluid.load_dygraph(args.save_dir) + policy.set_dict(model_dict) state, ep_reward = env.reset(), 0 for t in range(1, 10000): # Don't infinite loop while learning diff --git a/dygraph/reinforcement_learning/test_reinforce_load.py b/dygraph/reinforcement_learning/test_reinforce_load.py index fa468e9b..db7245d1 100644 --- a/dygraph/reinforcement_learning/test_reinforce_load.py +++ b/dygraph/reinforcement_learning/test_reinforce_load.py @@ -161,8 +161,8 @@ with fluid.dygraph.guard(): running_reward = 10 state, ep_reward = env.reset(), 0 - model_dict, _ = fluid.dygraph.load_persistables(args.save_dir) - policy.load_dict(model_dict) + model_dict, _ = fluid.load_dygraph(args.save_dir) + policy.set_dict(model_dict) for t in range(1, 10000): # Don't infinite loop while learning state = np.array(state).astype("float32") diff --git a/dygraph/resnet/train.py b/dygraph/resnet/train.py index 19cda2ae..5ce1d246 100644 --- a/dygraph/resnet/train.py +++ b/dygraph/resnet/train.py @@ -380,7 +380,7 @@ def train_resnet(): args.use_data_parallel and fluid.dygraph.parallel.Env().local_rank == 0) if save_parameters: - fluid.dygraph.save_persistables(resnet.state_dict(), + fluid.save_dygraph(resnet.state_dict(), 'resnet_params') diff --git a/dygraph/sentiment/main.py 
b/dygraph/sentiment/main.py index 279b3477..b22f7ee7 100644 --- a/dygraph/sentiment/main.py +++ b/dygraph/sentiment/main.py @@ -232,7 +232,7 @@ def train(): if steps % args.save_steps == 0: save_path = "save_dir_" + str(steps) print('save model to: ' + save_path) - fluid.dygraph.save_persistables(cnn_net.state_dict(), + fluid.save_dygraph(cnn_net.state_dict(), save_path) if enable_profile: print('save profile result into /tmp/profile_file') @@ -258,8 +258,8 @@ def infer(): print('Do inferring ...... ') total_acc, total_num_seqs = [], [] - restore, _ = fluid.dygraph.load_persistables(args.checkpoints) - cnn_net_infer.load_dict(restore) + restore, _ = fluid.load_dygraph(args.checkpoints) + cnn_net_infer.set_dict(restore) cnn_net_infer.eval() steps = 0 -- GitLab