diff --git a/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py b/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py index 70f240eecc2bc893ce7c4711fd24009436d30800..5d801fc1e5a4c32167f567129a5d7ab4710d0bbd 100644 --- a/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py +++ b/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py @@ -17,6 +17,7 @@ MNIST CNN Demo (LeNet5) import sys import os +import errno import numpy as np import time import logging @@ -117,7 +118,12 @@ def infer(): """ mpc_infer_data_dir = "./mpc_infer_data/" if not os.path.exists(mpc_infer_data_dir): - os.mkdir(mpc_infer_data_dir) + try: + os.mkdir(mpc_infer_data_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file_part = prediction_file + ".part{}".format(role) diff --git a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py index ed445f2c02984530322285742766786ddba88a73..d5c39a36b9ae2e083e04ef6df067babae62eb8ae 100644 --- a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py +++ b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py @@ -17,6 +17,7 @@ MNIST Demo import sys import os +import errno import numpy as np import time @@ -99,9 +100,15 @@ print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s' .format(epoch_num, BATCH_SIZE, (end_time - start_time))) # prediction + mpc_infer_data_dir = "./mpc_infer_data/" if not os.path.exists(mpc_infer_data_dir): - os.mkdir(mpc_infer_data_dir) + try: + os.mkdir(mpc_infer_data_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + prediction_file = mpc_infer_data_dir + "mnist_debug_prediction.part{}".format(role) if os.path.exists(prediction_file): os.remove(prediction_file) diff --git a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py index 457ce963d45da6957741b1a79e3d41fa2821dd6b..1e1fd7b5458ea534030370fff83651904832a3bd 100644 --- a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py +++ b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py @@ -17,6 +17,8 @@ MNIST CNN Demo (LeNet5) import sys import os +import errno + import numpy as np import time import logging @@ -91,7 +93,12 @@ def infer(): """ mpc_infer_data_dir = "./mpc_infer_data/" if not os.path.exists(mpc_infer_data_dir): - os.mkdir(mpc_infer_data_dir) + try: + os.mkdir(mpc_infer_data_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file_part = prediction_file + ".part{}".format(role)