未验证 提交 7a3e0c7d 编写于 作者: H hong 提交者: GitHub

use new save load in all dygraph modell (#3495)

* user new save load in all dygraph modell; test=develop

* change load_dict to set_dict; test=develop
上级 7e59194b
...@@ -52,8 +52,8 @@ def infer(): ...@@ -52,8 +52,8 @@ def infer():
os.makedirs(out_path) os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan") cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model save_dir = args.init_model
restore, _ = fluid.dygraph.load_persistables(save_dir) restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.load_dict(restore) cycle_gan.set_dict(restore)
cycle_gan.eval() cycle_gan.eval()
for file in glob.glob(args.input): for file in glob.glob(args.input):
print ("read %s" % file) print ("read %s" % file)
......
...@@ -52,8 +52,8 @@ def test(): ...@@ -52,8 +52,8 @@ def test():
os.makedirs(out_path) os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan") cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model + str(epoch) save_dir = args.init_model + str(epoch)
restore, _ = fluid.dygraph.load_persistables(save_dir) restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.load_dict(restore) cycle_gan.set_dict(restore)
cycle_gan.eval() cycle_gan.eval()
for data_A , data_B in zip(A_test_reader(), B_test_reader()): for data_A , data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[1] A_name = data_A[1]
......
...@@ -204,7 +204,7 @@ def train(args): ...@@ -204,7 +204,7 @@ def train(args):
break break
if args.save_checkpoints: if args.save_checkpoints:
fluid.dygraph.save_persistables( fluid.save_dygraph(
cycle_gan.state_dict(), cycle_gan.state_dict(),
args.output + "/checkpoints/{}".format(epoch)) args.output + "/checkpoints/{}".format(epoch))
......
...@@ -150,8 +150,8 @@ def inference_mnist(): ...@@ -150,8 +150,8 @@ def inference_mnist():
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist") mnist_infer = MNIST("mnist")
# load checkpoint # load checkpoint
model_dict, _ = fluid.dygraph.load_persistables("save_dir") model_dict, _ = fluid.load_dygraph("save_temp")
mnist_infer.load_dict(model_dict) mnist_infer.set_dict(model_dict)
print("checkpoint loaded") print("checkpoint loaded")
# start evaluate mode # start evaluate mode
...@@ -245,7 +245,7 @@ def train_mnist(args): ...@@ -245,7 +245,7 @@ def train_mnist(args):
args.use_data_parallel and args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0) fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters: if save_parameters:
fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir") fluid.save_dygraph(mnist.state_dict(), "save_temp")
print("checkpoint saved") print("checkpoint saved")
inference_mnist() inference_mnist()
......
...@@ -200,5 +200,5 @@ with fluid.dygraph.guard(): ...@@ -200,5 +200,5 @@ with fluid.dygraph.guard():
print("Solved! Running reward is now {} and " print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format( "the last episode runs to {} time steps!".format(
running_reward, t)) running_reward, t))
fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir) fluid.save_dygraph(policy.state_dict(), args.save_dir)
break break
...@@ -186,5 +186,5 @@ with fluid.dygraph.guard(): ...@@ -186,5 +186,5 @@ with fluid.dygraph.guard():
print("Solved! Running reward is now {} and " print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format( "the last episode runs to {} time steps!".format(
running_reward, t)) running_reward, t))
fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir) fluid.save_dygraph(policy.state_dict(), args.save_dir)
break break
...@@ -174,8 +174,8 @@ with fluid.dygraph.guard(): ...@@ -174,8 +174,8 @@ with fluid.dygraph.guard():
return returns return returns
running_reward = 10 running_reward = 10
model_dict, _ = fluid.dygraph.load_persistables(args.save_dir) model_dict, _ = fluid.load_dygraph(args.save_dir)
policy.load_dict(model_dict) policy.set_dict(model_dict)
state, ep_reward = env.reset(), 0 state, ep_reward = env.reset(), 0
for t in range(1, 10000): # Don't infinite loop while learning for t in range(1, 10000): # Don't infinite loop while learning
......
...@@ -161,8 +161,8 @@ with fluid.dygraph.guard(): ...@@ -161,8 +161,8 @@ with fluid.dygraph.guard():
running_reward = 10 running_reward = 10
state, ep_reward = env.reset(), 0 state, ep_reward = env.reset(), 0
model_dict, _ = fluid.dygraph.load_persistables(args.save_dir) model_dict, _ = fluid.load_dygraph(args.save_dir)
policy.load_dict(model_dict) policy.set_dict(model_dict)
for t in range(1, 10000): # Don't infinite loop while learning for t in range(1, 10000): # Don't infinite loop while learning
state = np.array(state).astype("float32") state = np.array(state).astype("float32")
......
...@@ -380,7 +380,7 @@ def train_resnet(): ...@@ -380,7 +380,7 @@ def train_resnet():
args.use_data_parallel and args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0) fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters: if save_parameters:
fluid.dygraph.save_persistables(resnet.state_dict(), fluid.save_dygraph(resnet.state_dict(),
'resnet_params') 'resnet_params')
......
...@@ -232,7 +232,7 @@ def train(): ...@@ -232,7 +232,7 @@ def train():
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = "save_dir_" + str(steps) save_path = "save_dir_" + str(steps)
print('save model to: ' + save_path) print('save model to: ' + save_path)
fluid.dygraph.save_persistables(cnn_net.state_dict(), fluid.save_dygraph(cnn_net.state_dict(),
save_path) save_path)
if enable_profile: if enable_profile:
print('save profile result into /tmp/profile_file') print('save profile result into /tmp/profile_file')
...@@ -258,8 +258,8 @@ def infer(): ...@@ -258,8 +258,8 @@ def infer():
print('Do inferring ...... ') print('Do inferring ...... ')
total_acc, total_num_seqs = [], [] total_acc, total_num_seqs = [], []
restore, _ = fluid.dygraph.load_persistables(args.checkpoints) restore, _ = fluid.load_dygraph(args.checkpoints)
cnn_net_infer.load_dict(restore) cnn_net_infer.set_dict(restore)
cnn_net_infer.eval() cnn_net_infer.eval()
steps = 0 steps = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册