提交 b9a37f56 编写于 作者: J jedhong

Fix dir creation race condition bug in examples

上级 24953638
...@@ -17,6 +17,7 @@ MNIST CNN Demo (LeNet5) ...@@ -17,6 +17,7 @@ MNIST CNN Demo (LeNet5)
import sys import sys
import os import os
import errno
import numpy as np import numpy as np
import time import time
import logging import logging
...@@ -117,7 +118,12 @@ def infer(): ...@@ -117,7 +118,12 @@ def infer():
""" """
mpc_infer_data_dir = "./mpc_infer_data/" mpc_infer_data_dir = "./mpc_infer_data/"
if not os.path.exists(mpc_infer_data_dir): if not os.path.exists(mpc_infer_data_dir):
os.mkdir(mpc_infer_data_dir) try:
os.mkdir(mpc_infer_data_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file = mpc_infer_data_dir + "mnist_debug_prediction"
prediction_file_part = prediction_file + ".part{}".format(role) prediction_file_part = prediction_file + ".part{}".format(role)
......
...@@ -17,6 +17,7 @@ MNIST Demo ...@@ -17,6 +17,7 @@ MNIST Demo
import sys import sys
import os import os
import errno
import numpy as np import numpy as np
import time import time
...@@ -99,9 +100,15 @@ print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s' ...@@ -99,9 +100,15 @@ print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s'
.format(epoch_num, BATCH_SIZE, (end_time - start_time))) .format(epoch_num, BATCH_SIZE, (end_time - start_time)))
# prediction # prediction
mpc_infer_data_dir = "./mpc_infer_data/" mpc_infer_data_dir = "./mpc_infer_data/"
if not os.path.exists(mpc_infer_data_dir): if not os.path.exists(mpc_infer_data_dir):
os.mkdir(mpc_infer_data_dir) try:
os.mkdir(mpc_infer_data_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
prediction_file = mpc_infer_data_dir + "mnist_debug_prediction.part{}".format(role) prediction_file = mpc_infer_data_dir + "mnist_debug_prediction.part{}".format(role)
if os.path.exists(prediction_file): if os.path.exists(prediction_file):
os.remove(prediction_file) os.remove(prediction_file)
......
...@@ -17,6 +17,8 @@ MNIST CNN Demo (LeNet5) ...@@ -17,6 +17,8 @@ MNIST CNN Demo (LeNet5)
import sys import sys
import os import os
import errno
import numpy as np import numpy as np
import time import time
import logging import logging
...@@ -91,7 +93,12 @@ def infer(): ...@@ -91,7 +93,12 @@ def infer():
""" """
mpc_infer_data_dir = "./mpc_infer_data/" mpc_infer_data_dir = "./mpc_infer_data/"
if not os.path.exists(mpc_infer_data_dir): if not os.path.exists(mpc_infer_data_dir):
os.mkdir(mpc_infer_data_dir) try:
os.mkdir(mpc_infer_data_dir)
except OSError as e:
if e.errno != errno.EEXIST:
raise
prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file = mpc_infer_data_dir + "mnist_debug_prediction"
prediction_file_part = prediction_file + ".part{}".format(role) prediction_file_part = prediction_file + ".part{}".format(role)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册