未验证 提交 13cba8d8 编写于 作者: Q Qinghe JING 提交者: GitHub

Merge pull request #110 from honshj/master

Fix dir creation race condition bug in examples
...@@ -17,6 +17,7 @@ MNIST CNN Demo (LeNet5) ...@@ -17,6 +17,7 @@ MNIST CNN Demo (LeNet5)
import sys import sys
import os import os
import errno
import numpy as np import numpy as np
import time import time
import logging import logging
...@@ -117,7 +118,12 @@ def infer(): ...@@ -117,7 +118,12 @@ def infer():
""" """
mpc_infer_data_dir = "./mpc_infer_data/" mpc_infer_data_dir = "./mpc_infer_data/"
if not os.path.exists(mpc_infer_data_dir): if not os.path.exists(mpc_infer_data_dir):
os.mkdir(mpc_infer_data_dir) try:
os.mkdir(mpc_infer_data_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file = mpc_infer_data_dir + "mnist_debug_prediction"
prediction_file_part = prediction_file + ".part{}".format(role) prediction_file_part = prediction_file + ".part{}".format(role)
......
...@@ -17,6 +17,7 @@ MNIST Demo ...@@ -17,6 +17,7 @@ MNIST Demo
import sys import sys
import os import os
import errno
import numpy as np import numpy as np
import time import time
...@@ -99,9 +100,15 @@ print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s' ...@@ -99,9 +100,15 @@ print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s'
.format(epoch_num, BATCH_SIZE, (end_time - start_time))) .format(epoch_num, BATCH_SIZE, (end_time - start_time)))
# prediction # prediction
mpc_infer_data_dir = "./mpc_infer_data/" mpc_infer_data_dir = "./mpc_infer_data/"
if not os.path.exists(mpc_infer_data_dir): if not os.path.exists(mpc_infer_data_dir):
os.mkdir(mpc_infer_data_dir) try:
os.mkdir(mpc_infer_data_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
prediction_file = mpc_infer_data_dir + "mnist_debug_prediction.part{}".format(role) prediction_file = mpc_infer_data_dir + "mnist_debug_prediction.part{}".format(role)
if os.path.exists(prediction_file): if os.path.exists(prediction_file):
os.remove(prediction_file) os.remove(prediction_file)
......
...@@ -17,6 +17,8 @@ MNIST CNN Demo (LeNet5) ...@@ -17,6 +17,8 @@ MNIST CNN Demo (LeNet5)
import sys import sys
import os import os
import errno
import numpy as np import numpy as np
import time import time
import logging import logging
...@@ -91,7 +93,12 @@ def infer(): ...@@ -91,7 +93,12 @@ def infer():
""" """
mpc_infer_data_dir = "./mpc_infer_data/" mpc_infer_data_dir = "./mpc_infer_data/"
if not os.path.exists(mpc_infer_data_dir): if not os.path.exists(mpc_infer_data_dir):
os.mkdir(mpc_infer_data_dir) try:
os.mkdir(mpc_infer_data_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file = mpc_infer_data_dir + "mnist_debug_prediction"
prediction_file_part = prediction_file + ".part{}".format(role) prediction_file_part = prediction_file + ".part{}".format(role)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册