提交 c1e7c0f8 编写于 作者: H hong 提交者: Divano

fix mnist multi gpu save error (#4128)

* fix mnist multi gpu save error; test=develop

* fix inference error; test=develop

* only run inference on rankid == 0; test=develop
上级 cba182c6
......@@ -244,9 +244,10 @@ def train_mnist(args):
fluid.dygraph.parallel.Env().local_rank == 0)
if save_parameters:
fluid.save_dygraph(mnist.state_dict(), "save_temp")
print("checkpoint saved")
inference_mnist()
inference_mnist()
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册