From b9a37f56fcc3b564267347a3dc767e30264665b0 Mon Sep 17 00:00:00 2001 From: jedhong Date: Fri, 28 Aug 2020 08:14:04 +0000 Subject: [PATCH] Fix dir creation race condition bug in examples --- .../mpc/examples/lenet_with_mnist/train_lenet.py | 8 +++++++- .../mpc/examples/logistic_with_mnist/train_fc_sigmoid.py | 9 ++++++++- .../mpc/examples/logistic_with_mnist/train_fc_softmax.py | 9 ++++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py b/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py index 70f240e..5d801fc 100644 --- a/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py +++ b/python/paddle_fl/mpc/examples/lenet_with_mnist/train_lenet.py @@ -17,6 +17,7 @@ MNIST CNN Demo (LeNet5) import sys import os +import errno import numpy as np import time import logging @@ -117,7 +118,12 @@ def infer(): """ mpc_infer_data_dir = "./mpc_infer_data/" if not os.path.exists(mpc_infer_data_dir): - os.mkdir(mpc_infer_data_dir) + try: + os.mkdir(mpc_infer_data_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file_part = prediction_file + ".part{}".format(role) diff --git a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py index ed445f2..d5c39a3 100644 --- a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py +++ b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_sigmoid.py @@ -17,6 +17,7 @@ MNIST Demo import sys import os +import errno import numpy as np import time @@ -99,9 +100,15 @@ print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s' .format(epoch_num, BATCH_SIZE, (end_time - start_time))) # prediction + mpc_infer_data_dir = "./mpc_infer_data/" if not os.path.exists(mpc_infer_data_dir): - os.mkdir(mpc_infer_data_dir) + try: + os.mkdir(mpc_infer_data_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + prediction_file = mpc_infer_data_dir + "mnist_debug_prediction.part{}".format(role) if os.path.exists(prediction_file): os.remove(prediction_file) diff --git a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py index 457ce96..1e1fd7b 100644 --- a/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py +++ b/python/paddle_fl/mpc/examples/logistic_with_mnist/train_fc_softmax.py @@ -17,6 +17,8 @@ MNIST CNN Demo (LeNet5) import sys import os +import errno + import numpy as np import time import logging @@ -91,7 +93,12 @@ def infer(): """ mpc_infer_data_dir = "./mpc_infer_data/" if not os.path.exists(mpc_infer_data_dir): - os.mkdir(mpc_infer_data_dir) + try: + os.mkdir(mpc_infer_data_dir) + except OSError as e: + if e.errno != errno.EEXIST: + raise + prediction_file = mpc_infer_data_dir + "mnist_debug_prediction" prediction_file_part = prediction_file + ".part{}".format(role) -- GitLab