未验证 提交 66e135cc 编写于 作者: J Jiabin Yang 提交者: GitHub

using new load api (#2372)

上级 623698ef
...@@ -148,7 +148,8 @@ def inference_mnist(): ...@@ -148,7 +148,8 @@ def inference_mnist():
with fluid.dygraph.guard(place): with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist") mnist_infer = MNIST("mnist")
# load checkpoint # load checkpoint
mnist_infer.load_dict(fluid.dygraph.load_persistables("save_dir")) model_dict, _ = fluid.dygraph.load_persistables("save_dir")
mnist_infer.load_dict(model_dict)
print("checkpoint loaded") print("checkpoint loaded")
# start evaluate mode # start evaluate mode
......
...@@ -174,7 +174,8 @@ with fluid.dygraph.guard(): ...@@ -174,7 +174,8 @@ with fluid.dygraph.guard():
return returns return returns
running_reward = 10 running_reward = 10
policy.load_dict(fluid.dygraph.load_persistables(args.save_dir)) model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
policy.load_dict(model_dict)
state, ep_reward = env.reset(), 0 state, ep_reward = env.reset(), 0
for t in range(1, 10000): # Don't infinite loop while learning for t in range(1, 10000): # Don't infinite loop while learning
......
...@@ -161,7 +161,8 @@ with fluid.dygraph.guard(): ...@@ -161,7 +161,8 @@ with fluid.dygraph.guard():
running_reward = 10 running_reward = 10
state, ep_reward = env.reset(), 0 state, ep_reward = env.reset(), 0
policy.load_dict(fluid.dygraph.load_persistables(args.save_dir)) model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
policy.load_dict(model_dict)
for t in range(1, 10000): # Don't infinite loop while learning for t in range(1, 10000): # Don't infinite loop while learning
state = np.array(state).astype("float32") state = np.array(state).astype("float32")
......
...@@ -231,7 +231,7 @@ def infer(): ...@@ -231,7 +231,7 @@ def infer():
print('Do inferring ...... ') print('Do inferring ...... ')
total_acc, total_num_seqs = [], [] total_acc, total_num_seqs = [], []
restore = fluid.dygraph.load_persistables(args.checkpoints) restore, _ = fluid.dygraph.load_persistables(args.checkpoints)
cnn_net_infer.load_dict(restore) cnn_net_infer.load_dict(restore)
cnn_net_infer.eval() cnn_net_infer.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册