提交 7c86e9c9 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Add SyncReplicasOptimizer test in dist_test

Usage example: ./remote_test.sh --num-workers 3 --sync-replicas

Also changed:
    1) In local and remote tests, let different workers contact separate GRPC
    sessions.
    2) In local and remote tests, adding the capacity to specify the number of
    workers. Before it was hard-coded at 2.
    Usage example:
    ./remote_test.sh --num-workers 2 --sync-replicas
    3) Using device setter in mnist_replica.py
Change: 119599547
上级 ea6cdc0d
......@@ -110,6 +110,7 @@ filegroup(
"//tensorflow/tensorboard/lib:all_files",
"//tensorflow/tensorboard/lib/python:all_files",
"//tensorflow/tensorboard/scripts:all_files",
"//tensorflow/tools/dist_test/server:all_files",
"//tensorflow/tools/docker:all_files",
"//tensorflow/tools/docker/notebooks:all_files",
"//tensorflow/tools/docs:all_files",
......
......@@ -53,13 +53,49 @@ The IP address above is a dummy example. Such a cluster may have been set up
using the command described at the end of the previous section.
**Building the test server Docker image**
**Asynchronous and synchronous parameter updates**
There are two modes for the coordination of the parameters from multiple
workers: asynchronous and synchrnous.
In the asynchronous mode, the parameter updates (gradients) from the workers
are applied to the parameters without any explicit coordination. This is the
default mode in the tests.
In the synchronous mode, a certain number of parameter updates are aggregated
from the model replicas before the update is applied to the model parameters.
To use this mode, do:
# For remote testing
./remote_test.sh --sync-replicas
# For local testing
./local_test.sh --sync-replicas
**Specifying the number of workers**
You can specify the number of workers by using the --num-workers option flag,
e.g.,
# For remote testing
./remote_test.sh --num-workers 4
# For local testing
./local_test.sh --num-workers 4
**Building the GRPC server Docker image**
To build the Docker image for a test server of TensorFlow distributed runtime,
run:
./build_server.sh <docker_image_name>
**Using the GRPC server Docker image**
To launch a container as a TensorFlow GRPC server, do as the following example:
docker run tensorflow/tf_grpc_server --cluster_spec="worker|localhost:2222;foo:2222,ps|bar:2222;qux:2222" --job_name=worker --task_id=0
**Generating configuration file for TensorFlow k8s clusters**
......@@ -71,6 +107,15 @@ workers and parameter servers. For example:
--num_workers 2 \
--num_parameter_servers 2 \
--grpc_port 2222 \
--request_load_balancer \
--docker_image "tensorflow/tf_grpc_test_server" \
--request_load_balancer true \
--docker_image "tensorflow/tf_grpc_server" \
> tf-k8s-with-lb.yaml
The yaml configuration file generated in the previous step can be used to a
create a k8s cluster running the specified numbers of worker and parameter
servers. For example:
kubectl create -f tf-k8s-with-lb.yaml
See [Kubernetes kubectl documentation]
(http://kubernetes.io/docs/user-guide/kubectl-overview/) for more details.
......@@ -16,7 +16,11 @@
#
# Builds the test server for distributed (GRPC) TensorFlow
#
# Usage: build_server.sh <docker_image_name>
# Usage: build_server.sh <docker_image_name> [--test]
#
# The optional flag --test lets the script to use the Dockerfile for the
# testing GRPC server. Without the flag, the script will build the non-test
# GRPC server.
#
# Note that the Dockerfile is located in ./server/ but the docker build should
# use the current directory as the context.
......@@ -29,16 +33,28 @@ die() {
}
# Check arguments
if [[ $# != 1 ]]; then
die "Usage: $0 <docker_image_name>"
if [[ $# != 1 ]] && [[ $# != 2 ]]; then
die "Usage: $0 <docker_image_name> [--test]"
fi
DOCKER_IMG_NAME=$1
shift
# Current script directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
DOCKER_FILE="${DIR}/server/Dockerfile"
if [[ $1 == "--test" ]]; then
DOCKER_FILE="${DIR}/server/Dockerfile.test"
fi
echo "Using Docker file: ${DOCKER_FILE}"
if [[ ! -f "${DOCKER_FILE}" ]]; then
die "ERROR: Unable to find dockerfile: ${DOCKER_FILE}"
fi
echo "Dockerfile: ${DOCKER_FILE}"
# Call docker build
docker build --no-cache -t "${DOCKER_IMG_NAME}" \
-f "${DIR}/server/Dockerfile" \
-f "${DOCKER_FILE}" \
"${DIR}"
......@@ -70,7 +70,7 @@ fi
COUNTER=1
while true; do
((COUNTER++))
docker run --net=host --privileged ${DOCKER_ENV} \
docker run --rm --net=host --privileged ${DOCKER_ENV} \
-v ${HOST_K8S_DIR}:/local/kubernetes \
${DOCKER_IMG_NAME} \
/var/tf-k8s/local/start_local_k8s_service.sh
......
......@@ -20,15 +20,34 @@
# This script assumes that a TensorFlow cluster is already running on the
# local machine and can be controlled by the "kubectl" binary.
#
# Usage: test_local_tf_cluster.sh
# Usage: test_local_tf_cluster.sh <NUM_WORKERS> <NUM_PARAMETER_SERVERS>
# [--sync-replicas]
#
# --sync-replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
export GCLOUD_BIN=/usr/local/bin/gcloud
export TF_DIST_LOCAL_CLUSTER=1
# TODO(cais): Do not hard-code the numbers of workers and ps
NUM_WORKERS=2
NUM_PARAMETER_SERVERS=2
# Parse input arguments
if [[ $# == 0 ]] || [[ $# == 1 ]]; then
echo "Usage: $0 <NUM_WORKERS> <NUM_PARAMETER_SERVERS>"
exit 1
fi
NUM_WORKERS=$1
NUM_PARAMETER_SERVERS=$2
SYNC_REPLICAS_FLAG=""
if [[ $3 == "--sync-replicas" ]]; then
SYNC_REPLICAS_FLAG="--sync-replicas"
fi
echo "NUM_WORKERS: ${NUM_WORKERS}"
echo "NUM_PARAMETER_SERVERS: ${NUM_PARAMETER_SERVERS}"
echo "SYNC_REPLICAS_FLAG: ${SYNC_REPLICAS_FLAG}"
# Get current script directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
......@@ -73,13 +92,28 @@ if [[ -z "${DOCKER_CONTAINER_ID}" ]]; then
die "FAILED to determine worker0 Docker container ID"
fi
export TF_DIST_GRPC_SERVER_URL="grpc://tf-worker0:2222"
GRPC_ENV="TF_DIST_GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}"
WORKER_URLS=""
IDX=0
while true; do
WORKER_URLS="${WORKER_URLS},grpc://tf-worker${IDX}:2222"
((IDX++))
if [[ ${IDX} == ${NUM_WORKERS} ]]; then
break
fi
done
echo "Worker URLs: ${WORKER_URLS}"
export TF_DIST_GRPC_SERVER_URLS="${WORKER_URLS}"
GRPC_ENV="TF_DIST_GRPC_SERVER_URLS=${TF_DIST_GRPC_SERVER_URLS}"
CMD="${GRPC_ENV} /var/tf-k8s/scripts/dist_test.sh "\
"--num-workers ${NUM_WORKERS} "\
"--num-parameter-servers ${NUM_PARAMETER_SERVERS} "\
"${SYNC_REPLICAS_FLAG}"
docker exec \
${DOCKER_CONTAINER_ID} \
/bin/bash -c \
"${GRPC_ENV} /var/tf-k8s/scripts/dist_test.sh"
docker exec ${DOCKER_CONTAINER_ID} /bin/bash -c "${CMD}"
if [[ $? != "0" ]]; then
die "Test of local k8s TensorFlow cluster FAILED"
......
......@@ -25,11 +25,25 @@
# and run the distributed test suite.
#
# Usage: local_test.sh [--leave-container-running]
# [--num-workers <NUM_WORKERS>]
# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
# [--sync-replicas]
#
# Arguments:
# --leave-container-running: Do not stop the docker-in-docker container after
# the termination of the tests, e.g., for debugging
#
# --num-workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
# --num-parameter-server <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
# --sync-replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
# In addition, this script obeys the following environment variables:
# TF_DIST_SERVER_DOCKER_IMAGE: overrides the default docker image to launch
# TensorFlow (GRPC) servers with
......@@ -48,6 +62,34 @@ get_container_id_by_image_name() {
echo $(docker ps | grep $1 | awk '{print $1}')
}
# Parse input arguments
LEAVE_CONTAINER_RUNNING=0
NUM_WORKERS=2
NUM_PARAMETER_SERVERS=2
SYNC_REPLICAS=0
while true; do
if [[ $1 == "--leave-container-running" ]]; then
LEAVE_CONTAINER_RUNNING=1
elif [[ $1 == "--num-workers" ]]; then
NUM_WORKERS=$2
elif [[ $1 == "--num-parameter-servers" ]]; then
NUM_PARAMETER_SERVERS=$2
elif [[ $1 == "--sync-replicas" ]]; then
SYNC_REPLICAS=1
fi
shift
if [[ -z $1 ]]; then
break
fi
done
echo "LEAVE_CONTAINER_RUNNING: ${LEAVE_CONTAINER_RUNNING}"
echo "NUM_WORKERS: ${NUM_WORKERS}"
echo "NUM_PARAMETER_SERVERS: ${NUM_PARAMETER_SERVERS}"
echo "SYNC_REPLICAS: ${SYNC_REPLICAS}"
# Current script directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
......@@ -126,12 +168,18 @@ echo "Launching k8s tf cluster and tests in container ${DIND_ID} ..."
echo ""
# Launch k8s tf cluster in the docker-in-docker container and perform tests
SYNC_REPLICAS_FLAG=""
if [[ ${SYNC_REPLICAS} == "1" ]]; then
SYNC_REPLICAS_FLAG="--sync-replicas"
fi
docker exec ${DIND_ID} \
/var/tf-k8s/local/test_local_tf_cluster.sh
/var/tf-k8s/local/test_local_tf_cluster.sh \
${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} ${SYNC_REPLICAS_FLAG}
TEST_RES=$?
# Tear down: stop docker-in-docker container
if [[ $1 != "--leave-container-running" ]]; then
if [[ ${LEAVE_CONTAINER_RUNNING} == "0" ]]; then
echo ""
echo "Stopping docker-in-docker container ${DIND_ID}"
......@@ -140,7 +188,7 @@ if [[ $1 != "--leave-container-running" ]]; then
echo ""
else
echo "Will not terminate DIND container ${DIND_ID}"
echo "Will NOT terminate DIND container ${DIND_ID}"
fi
if [[ "${TEST_RES}" != "0" ]]; then
......
......@@ -42,8 +42,6 @@ import sys
import tempfile
import time
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
......@@ -52,83 +50,203 @@ flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
"Directory for storing mnist data")
flags.DEFINE_boolean("download_only", False,
"""Only perform downloading of data; Do not proceed to
model definition or training""")
"Only perform downloading of data; Do not proceed to "
"session preparation, model definition or training")
flags.DEFINE_integer("worker_index", 0,
"""Worker task index, should be >= 0. worker_index=0 is
the master worker task the performs the variable
initialization""")
"Worker task index, should be >= 0. worker_index=0 is "
"the master worker task the performs the variable "
"initialization ")
flags.DEFINE_integer("num_workers", None,
"Total number of workers (must be >= 1)")
flags.DEFINE_integer("num_parameter_servers", 2,
"Total number of parameter servers (must be >= 1)")
flags.DEFINE_integer("replicas_to_aggregate", None,
"Number of replicas to aggregate before paramter update"
"is applied (For sync_replicas mode only; default: "
"num_workers)")
flags.DEFINE_integer("grpc_port", 2222,
"TensorFlow GRPC port")
flags.DEFINE_integer("hidden_units", 100,
"Number of units in the hidden layer of the NN")
flags.DEFINE_integer("train_steps", 50, "Number of training steps")
flags.DEFINE_integer("train_steps", 200,
"Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
flags.DEFINE_string("worker_grpc_url", None,
"Worker GRPC URL (e.g., grpc://1.2.3.4:2222, or "
"grpc://tf-worker0:2222)")
flags.DEFINE_boolean("sync_replicas", False,
"Use the sync_replicas (synchronized replicas) mode, "
"wherein the parameter updates from workersare aggregated "
"before applied to avoid stale gradients")
FLAGS = flags.FLAGS
IMAGE_PIXELS = 28
if __name__ == "__main__":
PARAM_SERVER_PREFIX = "tf-ps" # Prefix of the parameter servers' domain names
WORKER_PREFIX = "tf-worker" # Prefix of the workers' domain names
def get_device_setter(num_parameter_servers, num_workers):
"""Get a device setter given number of servers in the cluster.
Given the numbers of parameter servers and workers, construct a device
setter object using ClusterSpec.
Args:
num_parameter_servers: Number of parameter servers
num_workers: Number of workers
Returns:
Device setter object.
"""
ps_spec = []
for j in range(num_parameter_servers):
ps_spec.append("%s%d:%d" % (PARAM_SERVER_PREFIX, j, FLAGS.grpc_port))
worker_spec = []
for k in range(num_workers):
worker_spec.append("%s%d:%d" % (WORKER_PREFIX, k, FLAGS.grpc_port))
cluster_spec = tf.train.ClusterSpec({
"ps": ps_spec,
"worker": worker_spec})
# Get device setter from the cluster spec
return tf.train.replica_device_setter(cluster=cluster_spec)
def main(unused_argv):
mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
if FLAGS.download_only:
sys.exit(0)
print("Worker GRPC URL: %s" % FLAGS.worker_grpc_url)
print("Worker index = %d" % FLAGS.worker_index)
print("Number of workers = %d" % FLAGS.num_workers)
# Sanity check on the number of workers and the worker index
if FLAGS.worker_index >= FLAGS.num_workers:
raise ValueError("Worker index %d exceeds number of workers %d " %
(FLAGS.worker_index, FLAGS.num_workers))
# Sanity check on the number of parameter servers
if FLAGS.num_parameter_servers <= 0:
raise ValueError("Invalid num_parameter_servers value: %d" %
FLAGS.num_parameter_servers)
is_chief = (FLAGS.worker_index == 0)
if FLAGS.sync_replicas:
if FLAGS.replicas_to_aggregate is None:
replicas_to_aggregate = FLAGS.num_workers
else:
replicas_to_aggregate = FLAGS.replicas_to_aggregate
# Construct device setter object
device_setter = get_device_setter(FLAGS.num_parameter_servers,
FLAGS.num_workers)
# The device setter will automatically place Variables ops on separate
# parameter servers (ps). The non-Variable ops will be placed on the workers.
with tf.device(device_setter):
global_step = tf.Variable(0, name="global_step", trainable=False)
with tf.Graph().as_default():
# Variables of the hidden layer
with tf.device("/job:ps/task:0"):
hid_w = tf.Variable(
tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
hid_w = tf.Variable(
tf.truncated_normal([IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
stddev=1.0 / IMAGE_PIXELS), name="hid_w")
hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")
# Variables of the softmax layer
with tf.device("/job:ps/task:1"):
sm_w = tf.Variable(
tf.truncated_normal([FLAGS.hidden_units, 10],
stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
sm_w = tf.Variable(
tf.truncated_normal([FLAGS.hidden_units, 10],
stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
name="sm_w")
sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
# Ops: located on the worker specified with FLAGS.worker_index
with tf.device("/job:worker/task:%d" % FLAGS.worker_index):
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10])
x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
y_ = tf.placeholder(tf.float32, [None, 10])
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)
hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
hid = tf.nn.relu(hid_lin)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ *
tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
train_step = tf.train.AdamOptimizer(
FLAGS.learning_rate).minimize(cross_entropy)
y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
cross_entropy = -tf.reduce_sum(y_ *
tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
opt = tf.train.AdamOptimizer(FLAGS.learning_rate)
if FLAGS.sync_replicas:
opt = tf.train.SyncReplicasOptimizer(
opt,
replicas_to_aggregate=replicas_to_aggregate,
total_num_replicas=FLAGS.num_workers,
replica_id=FLAGS.worker_index,
name="mnist_sync_replicas")
train_step = opt.minimize(cross_entropy,
global_step=global_step)
if FLAGS.sync_replicas and is_chief:
# Initial token and chief queue runners required by the sync_replicas mode
chief_queue_runner = opt.get_chief_queue_runner()
init_tokens_op = opt.get_init_tokens_op()
init_op = tf.initialize_all_variables()
train_dir = tempfile.mkdtemp()
print(FLAGS.worker_index)
sv = tf.train.Supervisor(logdir=train_dir,
is_chief=(FLAGS.worker_index == 0))
sv = tf.train.Supervisor(is_chief=is_chief,
logdir=train_dir,
init_op=init_op,
recovery_wait_secs=1,
global_step=global_step)
sess_config = tf.ConfigProto(
allow_soft_placement=True,
log_device_placement=True,
device_filters=["/job:ps", "/job:worker/task:%d" % FLAGS.worker_index])
# The chief worker (worker_index==0) session will prepare the session,
# while the remaining workers will wait for the preparation to complete.
sess = sv.prepare_or_wait_for_session(FLAGS.worker_grpc_url)
if is_chief:
print("Worker %d: Initializing session..." % FLAGS.worker_index)
else:
print("Worker %d: Waiting for session to be initialized..." %
FLAGS.worker_index)
sess = sv.prepare_or_wait_for_session(FLAGS.worker_grpc_url,
config=sess_config)
print("Worker %d: Session initialization complete." % FLAGS.worker_index)
if FLAGS.sync_replicas and is_chief:
# Chief worker will start the chief queue runner and call the init op
print("Starting chief queue runner and running init_tokens_op")
sv.start_queue_runners(sess, [chief_queue_runner])
sess.run(init_tokens_op)
# Perform training
time_begin = time.time()
print("Training begins @ %f" % time_begin)
# TODO(cais): terminate when a global step counter reaches FLAGS.train_steps
for i in xrange(FLAGS.train_steps):
local_step = 0
while True:
# Training feed
batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
train_feed = {x: batch_xs,
y_: batch_ys}
sess.run(train_step, feed_dict=train_feed)
_, step = sess.run([train_step, global_step], feed_dict=train_feed)
local_step += 1
now = time.time()
print("%f: Worker %d: training step %d done (global step: %d)" %
(now, FLAGS.worker_index, local_step, step))
if step >= FLAGS.train_steps:
break
time_end = time.time()
print("Training ends @ %f" % time_end)
......@@ -142,3 +260,6 @@ if __name__ == "__main__":
print("After %d training step(s), validation cross entropy = %g" %
(FLAGS.train_steps, val_xent))
if __name__ == "__main__":
tf.app.run()
......@@ -21,11 +21,26 @@
#
# Usage:
# remote_test.sh [--setup-cluster-only]
# [--num-workers <NUM_WORKERS>]
# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
# [--sync-replicas]
#
# Arguments:
# --setup-cluster-only:
# Setup the TensorFlow k8s cluster only, and do not perform testing of
# the distributed runtime.
#
# --num-workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
# --num-parameter-server <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
# --sync-replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
#
# If any of the following environment variable has non-empty values, it will
# be mapped into the docker container to override the default values (see
......@@ -86,7 +101,7 @@ docker build ${NO_CACHE_FLAG} \
-t ${DOCKER_IMG_NAME} -f "${DIR}/Dockerfile" "${DIR}"
KEY_FILE_DIR=${TF_DIST_GCLOUD_KEY_FILE_DIR:-"${HOME}/gcloud-secrets"}
docker run -v ${KEY_FILE_DIR}:/var/gcloud/secrets \
docker run --rm -v ${KEY_FILE_DIR}:/var/gcloud/secrets \
${DOCKER_ENV_FLAGS} \
${DOCKER_IMG_NAME} \
/var/tf-dist-test/scripts/dist_test.sh $@
......@@ -160,6 +160,7 @@ if [[ ! -f "${K8S_YAML}" ]]; then
else
echo "Generated yaml configuration file for k8s TensorFlow cluster: "\
"${K8S_YAML}"
cat "${K8S_YAML}"
fi
# Create tf k8s container cluster
......@@ -167,7 +168,9 @@ fi
# Wait for external IP of worker services to become available
get_tf_worker_external_ip() {
echo $("${KUBECTL_BIN}" get svc | grep "^tf-worker0" | \
# Usage: gen_tf_worker_external_ip <WORKER_INDEX>
# E.g., gen_tf_worker_external_ip 2
echo $("${KUBECTL_BIN}" get svc | grep "^tf-worker${1}" | \
awk '{print $3}' | grep -E "[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+")
}
......@@ -184,15 +187,34 @@ if [[ ${IS_LOCAL_CLUSTER} == "0" ]]; then
"of tf-worker0 service to emerge"
fi
SVC_EXTERN_IP=$(get_tf_worker_external_ip)
EXTERN_IPS=""
WORKER_INDEX=0
N_AVAILABLE_EXTERNAL_IPS=0
while true; do
SVC_EXTERN_IP=$(get_tf_worker_external_ip ${WORKER_INDEX})
if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
break
if [[ ! -z "${SVC_EXTERN_IP}" ]]; then
EXTERN_IPS="${EXTERN_IPS} ${SVC_EXTERN_IP}"
((N_AVAILABLE_EXTERNAL_IPS++))
fi
((WORKER_INDEX++))
if [[ ${WORKER_INDEX} == ${NUM_WORKERS} ]]; then
break;
fi
done
if [[ ${N_AVAILABLE_EXTERNAL_IPS} == ${NUM_WORKERS} ]]; then
break;
fi
done
GRPC_SERVER_URL="grpc://${SVC_EXTERN_IP}:${GRPC_PORT}"
echo "GRPC URL of tf-worker0: ${GRPC_SERVER_URL}"
GRPC_SERVER_URLS=""
for IP in ${EXTERN_IPS}; do
GRPC_SERVER_URLS="${GRPC_SERVER_URLS} grpc://${IP}:${GRPC_PORT}"
done
echo "GRPC URLs of tf-workers: ${GRPC_SERVER_URLS}"
else
echo "Waiting for tf pods to be all running..."
......
......@@ -19,10 +19,22 @@
# grpc pods and service set up.
#
# Usage:
# dist_mnist_test.sh <worker_grpc_url>
# dist_mnist_test.sh <worker_grpc_urls>
# [--num-workers <NUM_WORKERS>]
# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
# [--sync-replicas]
#
# worker_grp_url is the IP address or the GRPC URL of the worker of the main
# worker session, e.g., grpc://1.2.3.4:2222
# --sync-replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
# worker_grp_url is the list of IP addresses or the GRPC URLs of the worker of
# the worker sessions, separated with spaces,
# e.g., "grpc://1.2.3.4:2222 grpc://5.6.7.8:2222"
#
# --num-workers <NUM_WORKERS>:
# Specifies the number of workers to run
# Configurations
......@@ -34,17 +46,55 @@ die() {
exit 1
}
if [[ $# != 1 ]]; then
die "Usage: $0 <WORKER_GRPC_URL>"
if [[ $# == "0" ]]; then
die "Usage: $0 <WORKER_GRPC_URLS> [--num-workers <NUM_WORKERS>] "\
"[--num-parameter-servers <NUM_PARAMETER_SERVERS>] [--sync-replicas]"
fi
WORKER_GRPC_URL=$1
# Verify the validity of the GRPC URL
if [[ -z $(echo "${WORKER_GRPC_URL}" | \
grep -E "^grpc://.+:[0-9]+") ]]; then
die "Invalid worker GRPC URL: \"${WORKER_GRPC_URL}\""
WORKER_GRPC_URLS=$1
shift
# Process additional input arguments
N_WORKERS=2 # Default value
N_PS=2 # Default value
SYNC_REPLICAS=0
while true; do
if [[ "$1" == "--num-workers" ]]; then
N_WORKERS=$2
elif [[ "$1" == "--num-parameter-servers" ]]; then
N_PS=$2
elif [[ "$1" == "--sync-replicas" ]]; then
SYNC_REPLICAS="1"
die "ERROR: --sync-replicas (synchronized-replicas) mode is not fully "\
"supported by this test yet."
# TODO(cais): Remove error message once sync-replicas is fully supported
fi
shift
if [[ -z "$1" ]]; then
break
fi
done
SYNC_REPLICAS_FLAG=""
if [[ ${SYNC_REPLICAS} == "1" ]]; then
SYNC_REPLICAS_FLAG="--sync_replicas"
fi
echo "N_WORKERS = ${N_WORKERS}"
echo "N_PS = ${N_PS}"
echo "SYNC_REPLICAS = ${SYNC_REPLICAS}"
echo "SYNC_REPLICAS_FLAG = ${SYNC_REPLICAS_FLAG}"
# Verify the validity of the GRPC URLs
for WORKER_GRPC_URL in ${WORKER_GRPC_URLS}; do
if [[ -z $(echo "${WORKER_GRPC_URL}" | \
grep -E "^grpc://.+:[0-9]+") ]]; then
die "Invalid worker GRPC URL: \"${WORKER_GRPC_URL}\""
fi
done
# Current working directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PY_DIR=$(dirname "${DIR}")/python
......@@ -55,37 +105,43 @@ WKR_LOG_PREFIX="/tmp/worker"
# First, download the data from a single process, to avoid race-condition
# during data downloading
WORKER_GRPC_URL_0=$(echo ${WORKER_GRPC_URLS} | awk '{print $1}')
timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
--download_only=True || \
--worker_grpc_url="${WORKER_GRPC_URL_0}" \
--worker_index=0 \
--num_workers=${N_WORKERS} \
--num_parameter_servers=${N_PS} \
${SYNC_REPLICAS_FLAG} \
--download_only || \
die "Download-only step of MNIST replica FAILED"
# Run a number of workers in parallel
N_WORKERS=2
echo "${N_WORKERS} worker process(es) running in parallel..."
INDICES=""
IDX=0
URLS=($WORKER_GRPC_URLS)
while true; do
timeout ${TIMEOUT} \
python "${MNIST_REPLICA}" \
--worker_grpc_url="${WORKER_GRPC_URL}" \
--worker_index=${IDX} 2>&1 > \
"${WKR_LOG_PREFIX}${IDX}.log" &
# TODO(cais): have each trainer process contact a different worker once
# supervisor and sync_replicas etc. are all working in OSS TensorFlow.
WORKER_GRPC_URL="${URLS[IDX]}"
timeout ${TIMEOUT} python "${MNIST_REPLICA}" \
--worker_grpc_url="${WORKER_GRPC_URL}" \
--worker_index=${IDX} \
--num_workers=${N_WORKERS} \
--num_parameter_servers=${N_PS} \
${SYNC_REPLICAS_FLAG} 2>&1 | tee "${WKR_LOG_PREFIX}${IDX}.log" &
echo "Worker ${IDX}: "
echo " GRPC URL: ${WORKER_GRPC_URL}"
echo " log file: ${WKR_LOG_PREFIX}${IDX}.log"
INDICES="${INDICES} ${IDX}"
((IDX++))
if [[ $(echo "${IDX}==${N_WORKERS}" | bc -l) == "1" ]]; then
if [[ "${IDX}" == "${N_WORKERS}" ]]; then
break
fi
done
# Function for getting final validation cross entropy from worker log files
get_final_val_xent() {
echo $(cat $1 | grep "^After.*validation cross entropy = " | \
awk '{print $NF}')
}
# Poll until all final validation cross entropy values become available or
# operation times out
COUNTER=0
......@@ -97,20 +153,19 @@ while true; do
fi
N_AVAIL=0
VAL_XENTS=""
VAL_XENT=""
for N in ${INDICES}; do
VAL_XENT=$(get_final_val_xent "${WKR_LOG_PREFIX}${N}.log")
if [[ ! -z ${VAL_XENT} ]]; then
if [[ ! -z $(grep "Training ends " "${WKR_LOG_PREFIX}${N}.log") ]]; then
((N_AVAIL++))
VAL_XENTS="${VAL_XENTS} ${VAL_XENT}"
fi
done
if [[ "${N_AVAIL}" == "2" ]]; then
if [[ "${N_AVAIL}" == "${N_WORKERS}" ]]; then
# Print out the content of the log files
for M in ${INDICES}; do
ORD=$(expr ${M} + 1)
echo "==================================================="
echo "=== Log file from worker ${M} ==="
echo "=== Log file from worker ${ORD} / ${N_WORKERS} ==="
cat "${WKR_LOG_PREFIX}${M}.log"
echo "==================================================="
echo ""
......@@ -122,16 +177,20 @@ while true; do
fi
done
# Function for getting final validation cross entropy from worker log files
get_final_val_xent() {
echo $(cat $1 | grep "^After.*validation cross entropy = " | \
awk '{print $NF}')
}
VAL_XENT=$(get_final_val_xent "${WKR_LOG_PREFIX}0.log")
# Sanity check on the validation entropies
# TODO(cais): In addition to this basic sanity check, we could run the training
# with 1 and 2 workers, each for a few times and use scipy.stats to do a t-test
# to verify tha tthe 2-worker training gives significantly lower final cross
# entropy
VAL_XENTS=(${VAL_XENTS})
for N in ${INDICES}; do
echo "Final validation cross entropy from worker${N}: ${VAL_XENTS[N]}"
if [[ $(echo "${VAL_XENTS[N]}>0" | bc -l) != "1" ]]; then
die "Sanity checks on the final validation cross entropy values FAILED"
fi
done
echo "Final validation cross entropy from worker0: ${VAL_XENT}"
if [[ $(echo "${VAL_XENT}>0" | bc -l) != "1" ]]; then
die "Sanity checks on the final validation cross entropy values FAILED"
fi
......@@ -26,64 +26,113 @@
#
# Usage:
# dist_test.sh [--setup-cluster-only]
# [--num-workers <NUM_WORKERS>]
# [--num-parameter-servers <NUM_PARAMETER_SERVERS>]
# [--sync-replicas]
#
# --setup-cluster-only:
# Lets the script only set up the k8s container network
#
# --num-workers <NUM_WORKERS>:
# Specifies the number of worker pods to start
#
# --num-parameter-server <NUM_PARAMETER_SERVERS>:
# Specifies the number of parameter servers to start
#
# --sync-replicas
# Use the synchronized-replica mode. The parameter updates from the replicas
# (workers) will be aggregated before applied, which avoids stale parameter
# updates.
#
# --setup-cluster-only lets the script only set up the k8s container network
#
# This script obeys values in the folllowing environment variables:
# TF_DIST_GRPC_SERVER_URL: If it is set to a valid grpc server url (e.g.,
# (grpc://1.2.3.4:2222), the script will bypass
# the cluster setup and teardown processes and
# just use this URL.
# TF_DIST_GRPC_SERVER_URLS: If it is set to a list of valid server urls,
# separated with spaces or commas
# (e.g., "grpc://1.2.3.4:2222 grpc//5.6.7.8:2222"),
# the script will bypass the cluster setup and
# teardown processes and just use this URL.
# Configurations
NUM_WORKERS=2 # Number of worker container
NUM_PARAMETER_SERVERS=2 # Number of parameter servers
# Helper functions
die() {
echo $@
exit 1
}
# Parse input arguments: number of workers
# Default values:
NUM_WORKERS=2 # Number of worker container
NUM_PARAMETER_SERVERS=2 # Number of parameter servers
SYNC_REPLICAS=0
SETUP_CLUSTER_ONLY=0
while true; do
if [[ "$1" == "--num-workers" ]]; then
NUM_WORKERS=$2
elif [[ "$1" == "--num-parameter-servers" ]]; then
NUM_PARAMETER_SERVERS=$2
elif [[ "$1" == "--sync-replicas" ]]; then
SYNC_REPLICAS=1
elif [[ "$1" == "--setup-cluster-only" ]]; then
SETUP_CLUSTER_ONLY=1
fi
shift
if [[ -z "$1" ]]; then
break
fi
done
echo "NUM_WORKERS = ${NUM_WORKERS}"
echo "NUM_PARAMETER_SERVERS = ${NUM_PARAMETER_SERVERS}"
echo "SETUP_CLUSTER_ONLY = ${SETUP_CLUSTER_ONLY}"
# gcloud operation timeout (steps)
GCLOUD_OP_MAX_STEPS=240
GRPC_SERVER_URL=${TF_DIST_GRPC_SERVER_URL}
if [[ ! -z ${TF_DIST_GRPC_SERVER_URLS} ]]; then
GRPC_SERVER_URLS=${TF_DIST_GRPC_SERVER_URLS}
GRPC_SERVER_URLS=$(echo ${GRPC_SERVER_URLS} | sed -e 's/,/ /g')
fi
# Report gcloud / GKE parameters
echo "GRPC_SERVER_URL: ${GRPC_SERVER_URL}"
echo "GRPC_SERVER_URLS: ${GRPC_SERVER_URLS}"
echo "SYNC_REPLICAS: ${SYNC_REPLICAS}"
# Get current script directory
DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# Locate path to kubectl binary
TEARDOWN_WHEN_DONE=1
if [[ ! -z "${GRPC_SERVER_URL}" ]]; then
if [[ ! -z "${GRPC_SERVER_URLS}" ]]; then
TEARDOWN_WHEN_DONE=0
# Verify the validity of the GRPC URL
if [[ -z $(echo "${GRPC_SERVER_URL}" | \
for GRPC_SERVER_URL in ${GRPC_SEVER_URLS}; do
if [[ -z $(echo "${GRPC_SERVER_URL}" | \
grep -E "^grpc://.+:[0-9]+") ]]; then
die "Invalid GRPC_SERVER_URL: \"${GRPC_SERVER_URL}\""
else
echo "The preset GRPC_SERVER_URL appears to be valid: ${GRPC_SERVER_URL}"
echo "Will bypass the TensorFlow k8s cluster setup and teardown process"
echo ""
fi
die "Invalid GRPC_SERVER_URL: \"${GRPC_SERVER_URL}\""
fi
done
echo "The preset GRPC_SERVER_URLS appears to be valid: ${GRPC_SERVER_URLS}"
echo "Will bypass the TensorFlow k8s cluster setup and teardown process"
echo ""
else
TMP=$(mktemp)
"${DIR}/create_tf_cluster.sh" ${NUM_WORKERS} ${NUM_PARAMETER_SERVERS} 2>&1 | \
tee "${TMP}" || \
die "Creation of TensorFlow k8s cluster FAILED"
GRPC_SERVER_URL=$(cat ${TMP} | grep "GRPC URL of tf-worker0: .*" | \
awk '{print $NF}')
if [[ -z "${GRPC_SERVER_URL}" ]]; then
die "FAILED to determine GRPC server URL"
GRPC_SERVER_URLS=$(cat ${TMP} | grep "GRPC URLs of tf-workers: .*" | \
sed -e 's/GRPC URLs of tf-workers://g')
if [[ $(echo ${GRPC_SERVER_URLS} | wc -w) != ${NUM_WORKERS} ]]; then
die "FAILED to determine GRPC server URLs of all workers"
fi
rm -f ${TMP}
if [[ $1 == "--setup-cluster-only" ]]; then
if [[ ${SETUP_CLUSTER_ONLY} == "1" ]]; then
echo "Skipping testing of distributed runtime due to "\
"option flag --setup-cluster-only"
exit 0
......@@ -97,10 +146,18 @@ if [[ ! -f "${MNIST_DIST_TEST_BIN}" ]]; then
"${MNIST_DIST_TEST_BIN}"
fi
echo "Performing distributed MNIST training through grpc session @ "\
"${GRPC_SERVER_URL}..."
echo "Performing distributed MNIST training through grpc sessions @ "\
"${GRPC_SERVER_URLS}..."
SYNC_REPLICAS_FLAG=""
if [[ ${SYNC_REPLICAS} == "1" ]]; then
SYNC_REPLICAS_FLAG="--sync-replicas"
fi
"${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URL}"
"${MNIST_DIST_TEST_BIN}" "${GRPC_SERVER_URLS}" \
--num-workers "${NUM_WORKERS}" \
--num-parameter-servers "${NUM_PARAMETER_SERVERS}" \
${SYNC_REPLICAS_FLAG}
if [[ $? == "0" ]]; then
echo "MNIST-replica test PASSED"
......
......@@ -180,11 +180,11 @@ def GenerateConfig(num_workers,
port=port,
worker_id=worker,
docker_image=docker_image,
cluster_spec=WorkerClusterSpec(num_workers,
num_param_servers,
port))
cluster_spec=WorkerClusterSpecString(num_workers,
num_param_servers,
port))
config += '---\n'
if worker == 0 and request_load_balancer:
if request_load_balancer:
config += WORKER_LB_SVC.format(port=port,
worker_id=worker)
else:
......@@ -197,9 +197,9 @@ def GenerateConfig(num_workers,
port=port,
param_server_id=param_server,
docker_image=docker_image,
cluster_spec=ParamServerClusterSpec(num_workers,
num_param_servers,
port))
cluster_spec=ParamServerClusterSpecString(num_workers,
num_param_servers,
port))
config += '---\n'
config += PARAM_SERVER_SVC.format(port=port,
param_server_id=param_server)
......@@ -208,23 +208,23 @@ def GenerateConfig(num_workers,
return config
def WorkerClusterSpec(num_workers,
num_param_servers,
port):
def WorkerClusterSpecString(num_workers,
num_param_servers,
port):
"""Generates worker cluster spec."""
return ClusterSpec(num_workers, num_param_servers, port)
return ClusterSpecString(num_workers, num_param_servers, port)
def ParamServerClusterSpec(num_workers,
num_param_servers,
port):
def ParamServerClusterSpecString(num_workers,
num_param_servers,
port):
"""Generates parameter server spec."""
return ClusterSpec(num_workers, num_param_servers, port)
return ClusterSpecString(num_workers, num_param_servers, port)
def ClusterSpec(num_workers,
num_param_servers,
port):
def ClusterSpecString(num_workers,
num_param_servers,
port):
"""Generates general cluster spec."""
spec = 'worker|'
for worker in range(num_workers):
......
# Description:
# TensorFlow GRPC distributed runtime server and tests
package(default_visibility = ["//visibility:private"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
py_library(
name = "grpc_tensorflow_server",
srcs = [
"grpc_tensorflow_server.py",
],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
"//tensorflow:tensorflow_py",
],
)
py_test(
name = "parse_cluster_spec_test",
size = "small",
srcs = [
"parse_cluster_spec_test.py",
],
main = "parse_cluster_spec_test.py",
srcs_version = "PY2AND3",
deps = [
":grpc_tensorflow_server",
"//tensorflow:tensorflow_py",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
......@@ -23,9 +23,7 @@ MAINTAINER Shanqing Cai <cais@google.com>
# Pick up some TF dependencies
RUN apt-get update && apt-get install -y \
bc \
curl \
dnsutils \
python-numpy \
python-pip \
&& \
......@@ -36,7 +34,7 @@ RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
rm get-pip.py
# Install TensorFlow CPU version.
# Install TensorFlow CPU version from nightly build
RUN pip --no-cache-dir install \
http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
......@@ -44,16 +42,5 @@ RUN pip --no-cache-dir install \
# server/grpc_tensorflow_server.py
ADD . /var/tf-k8s
# Download MNIST data for tests
RUN mkdir -p /tmp/mnist-data
RUN curl -o /tmp/mnist-data/train-labels-idx1-ubyte.gz \
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
RUN curl -o /tmp/mnist-data/train-images-idx3-ubyte.gz \
http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
RUN curl -o /tmp/mnist-data/t10k-labels-idx1-ubyte.gz \
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
# Container entry point
ENTRYPOINT ["/var/tf-k8s/server/grpc_tensorflow_server.py"]
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Test server for TensorFlow GRPC server
#
# To build the image, use ../build_server.sh --test
FROM ubuntu:14.04
MAINTAINER Shanqing Cai <cais@google.com>
# Pick up some TF dependencies
RUN apt-get update && apt-get install -y \
bc \
curl \
dnsutils \
python-numpy \
python-pip \
&& \
apt-get clean && \
rm -rf /var/lib/apt/lists/*
RUN curl -O https://bootstrap.pypa.io/get-pip.py && \
python get-pip.py && \
rm get-pip.py
# Install TensorFlow CPU version.
RUN pip --no-cache-dir install \
http://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.7.1-cp27-none-linux_x86_64.whl
# Copy files, including the GRPC server binary at
# server/grpc_tensorflow_server.py
ADD . /var/tf-k8s
# Download MNIST data for tests
RUN mkdir -p /tmp/mnist-data
RUN curl -o /tmp/mnist-data/train-labels-idx1-ubyte.gz \
http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
RUN curl -o /tmp/mnist-data/train-images-idx3-ubyte.gz \
http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
RUN curl -o /tmp/mnist-data/t10k-labels-idx1-ubyte.gz \
http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
RUN curl -o /tmp/mnist-data/t10k-images-idx3-ubyte.gz \
http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
# Container entry point
ENTRYPOINT ["/var/tf-k8s/server/grpc_tensorflow_server_wrapper.sh"]
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow GRPC server."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
......@@ -67,6 +67,9 @@ def parse_cluster_spec(cluster_spec, cluster):
job_strings = cluster_spec.split(",")
if not cluster_spec:
raise ValueError("Empty cluster_spec string")
for job_string in job_strings:
job_def = cluster.job.add()
......@@ -86,7 +89,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_tasks = job_string.split("|")[1].split(";")
for i in range(len(job_tasks)):
if not job_tasks[i]:
raise ValueError("Empty job_task string at position %d" % i)
raise ValueError("Empty task string at position %d" % i)
job_def.tasks[i] = job_tasks[i]
......@@ -96,7 +99,7 @@ def parse_cluster_spec(cluster_spec, cluster):
def main(unused_args):
# Create Protobuf ServerDef
server_def = tf.ServerDef(protocol="grpc")
server_def = tf.train.ServerDef(protocol="grpc")
# Cluster info
parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster)
......@@ -111,8 +114,8 @@ def main(unused_args):
raise ValueError("Invalid task_id: %d" % FLAGS.task_id)
server_def.task_index = FLAGS.task_id
# Create GrpcServer instance
server = tf.GrpcServer(server_def)
# Create GRPC Server instance
server = tf.train.Server(server_def)
# join() is blocking, unlike start()
server.join()
......
#!/usr/bin/env bash
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#
# Wrapper script for grpc_tensorflow_server.py in test server
LOG_FILE="/tmp/grpc_tensorflow_server.log"
SCRIPT_DIR=$( cd ${0%/*} && pwd -P )
touch "${LOG_FILE}"
python ${SCRIPT_DIR}/grpc_tensorflow_server.py $@ 2>&1 | tee "${LOG_FILE}"
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for cluster-spec string parser in GRPC TensorFlow server."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.tools.dist_test.server import grpc_tensorflow_server
class ParseClusterSpecStringTest(tf.test.TestCase):
def setUp(self):
self._cluster = tf.train.ServerDef(protocol="grpc").cluster
def test_parse_multi_jobs_sunnyday(self):
cluster_spec = ("worker|worker0:2220;worker1:2221;worker2:2222,"
"ps|ps0:3220;ps1:3221")
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
self.assertEqual(2, len(self._cluster.job))
self.assertEqual("worker", self._cluster.job[0].name)
self.assertEqual(3, len(self._cluster.job[0].tasks))
self.assertEqual("worker0:2220", self._cluster.job[0].tasks[0])
self.assertEqual("worker1:2221", self._cluster.job[0].tasks[1])
self.assertEqual("worker2:2222", self._cluster.job[0].tasks[2])
self.assertEqual("ps", self._cluster.job[1].name)
self.assertEqual(2, len(self._cluster.job[1].tasks))
self.assertEqual("ps0:3220", self._cluster.job[1].tasks[0])
self.assertEqual("ps1:3221", self._cluster.job[1].tasks[1])
def test_empty_cluster_spec_string(self):
cluster_spec = ""
with self.assertRaisesRegexp(ValueError,
"Empty cluster_spec string"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
def test_parse_misused_comma_for_semicolon(self):
cluster_spec = "worker|worker0:2220,worker1:2221"
with self.assertRaisesRegexp(ValueError,
"Not exactly one instance of \\'\\|\\'"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
def test_parse_misused_semicolon_for_comma(self):
cluster_spec = "worker|worker0:2220;ps|ps0:3220"
with self.assertRaisesRegexp(ValueError,
"Not exactly one instance of \\'\\|\\'"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
def test_parse_empty_job_name(self):
cluster_spec = "worker|worker0:2220,|ps0:3220"
with self.assertRaisesRegexp(ValueError,
"Empty job_name in cluster_spec"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
print(self._cluster)
def test_parse_empty_task(self):
cluster_spec = "worker|worker0:2220,ps|"
with self.assertRaisesRegexp(ValueError,
"Empty task string at position 0"):
grpc_tensorflow_server.parse_cluster_spec(cluster_spec, self._cluster)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册