From 682ab437a2db59134bc24dfa711e1f352f594e20 Mon Sep 17 00:00:00 2001 From: heya02 Date: Fri, 31 Jul 2020 09:43:21 +0000 Subject: [PATCH] update scripts and add param for mnist_demo.py --- .../mpc/examples/mnist_demo/mnist_demo.py | 17 ++++++++++------ .../mpc/examples/mnist_demo/run_standalone.sh | 20 ++++++++++++++----- 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py b/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py index 5a8b0a9..d1afe12 100644 --- a/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py +++ b/python/paddle_fl/mpc/examples/mnist_demo/mnist_demo.py @@ -16,6 +16,7 @@ MNIST Demo """ import sys +import os import numpy as np import time @@ -26,9 +27,11 @@ import paddle_fl.mpc as pfl_mpc import paddle_fl.mpc.data_utils.aby3 as aby3 +env_dist = os.environ +local_host= env_dist.get('LOCALHOST') role, server, port = sys.argv[1], sys.argv[2], sys.argv[3] # modify host(localhost). -pfl_mpc.init("aby3", int(role), "localhost", server, int(port)) +pfl_mpc.init("aby3", int(role), local_host, server, int(port)) role = int(role) # data preprocessing @@ -78,18 +81,20 @@ test_loader.set_batch_generator(test_batch_sample, places=place) exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) -start_time = time.time() -step = 0 for epoch_id in range(epoch_num): + start_time = time.time() + step = 0 # feed data via loader for sample in loader(): + batch_start = time.time() exe.run(feed=sample, fetch_list=[cost.name]) + batch_end = time.time() if step % 50 == 0: - print('Epoch={}, Step={}'.format(epoch_id, step)) + print('Epoch={}, Step={}, batch_cost={:.4f} s'.format(epoch_id, step, (batch_end - batch_start))) step += 1 -end_time = time.time() -print('Mpc Training of Epoch={} Batch_size={}, cost time in seconds:{}' + end_time = time.time() + print('Mpc Training of Epoch={} Batch_size={}, epoch_cost={:.4f} s' .format(epoch_num, BATCH_SIZE, (end_time - start_time))) # prediction diff --git a/python/paddle_fl/mpc/examples/mnist_demo/run_standalone.sh b/python/paddle_fl/mpc/examples/mnist_demo/run_standalone.sh index 33ce47d..fb9e7f7 100755 --- a/python/paddle_fl/mpc/examples/mnist_demo/run_standalone.sh +++ b/python/paddle_fl/mpc/examples/mnist_demo/run_standalone.sh @@ -31,12 +31,13 @@ # bash run_standalone.sh TEST_SCRIPT_NAME # -# modify the following vars according to your environment -PYTHON="python" -REDIS_HOME="path_to_redis_bin" -SERVER="localhost" -PORT=9937 +# please set the following environment vars according in your environment +PYTHON=${PYTHON} +REDIS_HOME=${PATH_TO_REDIS_BIN} +SERVER=${LOCALHOST} +PORT=${REDIS_PORT} +echo "redis home in ${REDIS_HOME}, server is ${SERVER}, port is ${PORT}" function usage() { echo 'run_standalone.sh SCRIPT_NAME [ARG...]' exit 0 @@ -67,6 +68,15 @@ if [ "$PRED_FILE" ]; then rm -rf $PRED_FILE fi +PRED_FILE="/tmp/mnist2_feature.part*" +if [ ! "$PRED_FILE" ]; then + echo "There is no data in /tmp, please prepare data with "python prepare.py" firstly" + exit 1 +else + echo "There are data for mnist:" + echo "`ls ${PRED_FILE}`" +fi + # kick off script with roles of 1 and 2, and redirect output to /dev/null for role in {1..2}; do -- GitLab