未验证 提交 7a3e0c7d 编写于 作者: H hong 提交者: GitHub

use new save load in all dygraph modell (#3495)

* user new save load in all dygraph modell; test=develop

* change load_dict to set_dict; test=develop
上级 7e59194b
......@@ -52,8 +52,8 @@ def infer():
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model
restore, _ = fluid.dygraph.load_persistables(save_dir)
cycle_gan.load_dict(restore)
restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.set_dict(restore)
cycle_gan.eval()
for file in glob.glob(args.input):
print ("read %s" % file)
......
......@@ -52,8 +52,8 @@ def test():
os.makedirs(out_path)
cycle_gan = Cycle_Gan("cycle_gan")
save_dir = args.init_model + str(epoch)
restore, _ = fluid.dygraph.load_persistables(save_dir)
cycle_gan.load_dict(restore)
restore, _ = fluid.load_dygraph(save_dir)
cycle_gan.set_dict(restore)
cycle_gan.eval()
for data_A , data_B in zip(A_test_reader(), B_test_reader()):
A_name = data_A[1]
......
......@@ -204,7 +204,7 @@ def train(args):
break
if args.save_checkpoints:
fluid.dygraph.save_persistables(
fluid.save_dygraph(
cycle_gan.state_dict(),
args.output + "/checkpoints/{}".format(epoch))
......
......@@ -150,8 +150,8 @@ def inference_mnist():
with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist")
# load checkpoint
model_dict, _ = fluid.dygraph.load_persistables("save_dir")
mnist_infer.load_dict(model_dict)
model_dict, _ = fluid.load_dygraph("save_temp")
mnist_infer.set_dict(model_dict)
print("checkpoint loaded")
# start evaluate mode
......@@ -245,7 +245,7 @@ def train_mnist(args):
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.dygraph.save_persistables(mnist.state_dict(), "save_dir")
fluid.save_dygraph(mnist.state_dict(), "save_temp")
print("checkpoint saved")
inference_mnist()
......
......@@ -200,5 +200,5 @@ with fluid.dygraph.guard():
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(
running_reward, t))
fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir)
fluid.save_dygraph(policy.state_dict(), args.save_dir)
break
......@@ -186,5 +186,5 @@ with fluid.dygraph.guard():
print("Solved! Running reward is now {} and "
"the last episode runs to {} time steps!".format(
running_reward, t))
fluid.dygraph.save_persistables(policy.state_dict(), args.save_dir)
fluid.save_dygraph(policy.state_dict(), args.save_dir)
break
......@@ -174,8 +174,8 @@ with fluid.dygraph.guard():
return returns
running_reward = 10
model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
policy.load_dict(model_dict)
model_dict, _ = fluid.load_dygraph(args.save_dir)
policy.set_dict(model_dict)
state, ep_reward = env.reset(), 0
for t in range(1, 10000): # Don't infinite loop while learning
......
......@@ -161,8 +161,8 @@ with fluid.dygraph.guard():
running_reward = 10
state, ep_reward = env.reset(), 0
model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
policy.load_dict(model_dict)
model_dict, _ = fluid.load_dygraph(args.save_dir)
policy.set_dict(model_dict)
for t in range(1, 10000): # Don't infinite loop while learning
state = np.array(state).astype("float32")
......
......@@ -380,7 +380,7 @@ def train_resnet():
args.use_data_parallel and
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.dygraph.save_persistables(resnet.state_dict(),
fluid.save_dygraph(resnet.state_dict(),
'resnet_params')
......
......@@ -232,7 +232,7 @@ def train():
if steps % args.save_steps == 0:
save_path = "save_dir_" + str(steps)
print('save model to: ' + save_path)
fluid.dygraph.save_persistables(cnn_net.state_dict(),
fluid.save_dygraph(cnn_net.state_dict(),
save_path)
if enable_profile:
print('save profile result into /tmp/profile_file')
......@@ -258,8 +258,8 @@ def infer():
print('Do inferring ...... ')
total_acc, total_num_seqs = [], []
restore, _ = fluid.dygraph.load_persistables(args.checkpoints)
cnn_net_infer.load_dict(restore)
restore, _ = fluid.load_dygraph(args.checkpoints)
cnn_net_infer.set_dict(restore)
cnn_net_infer.eval()
steps = 0
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册