未验证 提交 f07c9635 编写于 作者: C chengduo 提交者: GitHub

update mnist dygraph (#2320)

上级 fe7043f9
......@@ -143,7 +143,9 @@ def test_mnist(reader, model, batch_size):
def inference_mnist():
with fluid.dygraph.guard():
place = fluid.CUDAPlace(fluid.dygraph.parallel.Env().dev_id) \
if args.use_data_parallel else fluid.CUDAPlace(0)
with fluid.dygraph.guard(place):
mnist_infer = MNIST("mnist")
# load checkpoint
mnist_infer.load_dict(fluid.dygraph.load_persistables("save_dir"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册