未验证 提交 66e135cc 编写于 作者: J Jiabin Yang 提交者: GitHub

using new load api (#2372)

上级 623698ef
......@@ -148,7 +148,8 @@ def inference_mnist():
with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist")
# load checkpoint
mnist_infer.load_dict(fluid.dygraph.load_persistables("save_dir"))
model_dict, _ = fluid.dygraph.load_persistables("save_dir")
mnist_infer.load_dict(model_dict)
print("checkpoint loaded")
# start evaluate mode
......
......@@ -174,7 +174,8 @@ with fluid.dygraph.guard():
return returns
running_reward = 10
policy.load_dict(fluid.dygraph.load_persistables(args.save_dir))
model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
policy.load_dict(model_dict)
state, ep_reward = env.reset(), 0
for t in range(1, 10000): # Don't infinite loop while learning
......
......@@ -161,7 +161,8 @@ with fluid.dygraph.guard():
running_reward = 10
state, ep_reward = env.reset(), 0
policy.load_dict(fluid.dygraph.load_persistables(args.save_dir))
model_dict, _ = fluid.dygraph.load_persistables(args.save_dir)
policy.load_dict(model_dict)
for t in range(1, 10000): # Don't infinite loop while learning
state = np.array(state).astype("float32")
......
......@@ -231,7 +231,7 @@ def infer():
print('Do inferring ...... ')
total_acc, total_num_seqs = [], []
restore = fluid.dygraph.load_persistables(args.checkpoints)
restore, _ = fluid.dygraph.load_persistables(args.checkpoints)
cnn_net_infer.load_dict(restore)
cnn_net_infer.eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册