提交 2bdf0b4e 编写于 作者: V Vijay Vasudevan 提交者: TensorFlower Gardener

Convert more flags use to argparse in dist_test

Change: 144278086
上级 e11fae78
......@@ -20,8 +20,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import json
import os
import sys
from six.moves import urllib
import tensorflow as tf
......@@ -30,28 +32,6 @@ from tensorflow.contrib.learn.python.learn import learn_runner
from tensorflow.contrib.learn.python.learn.estimators import run_config
# Define command-line flags
flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/census-data",
"Directory for storing the cesnsus data")
flags.DEFINE_string("model_dir", "/tmp/census_wide_and_deep_model",
"Directory for storing the model")
flags.DEFINE_string("output_dir", "", "Base output directory.")
flags.DEFINE_string("schedule", "local_run",
"Schedule to run for this experiment.")
flags.DEFINE_string("master_grpc_url", "",
"URL to master GRPC tensorflow server, e.g.,"
"grpc://127.0.0.1:2222")
flags.DEFINE_integer("num_parameter_servers", 0,
"Number of parameter servers")
flags.DEFINE_integer("worker_index", 0,
"Worker index (>=0)")
flags.DEFINE_integer("train_steps", 1000, "Number of training steps")
flags.DEFINE_integer("eval_steps", 1, "Number of evaluation steps")
FLAGS = flags.FLAGS
# Constants: Data download URLs
TRAIN_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data"
TEST_DATA_URL = "http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test"
......@@ -277,4 +257,62 @@ def main(unused_argv):
if __name__ == "__main__":
tf.app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--data_dir",
type=str,
default="/tmp/census-data",
help="Directory for storing the cesnsus data"
)
parser.add_argument(
"--model_dir",
type=str,
default="/tmp/census_wide_and_deep_model",
help="Directory for storing the model"
)
parser.add_argument(
"--output_dir",
type=str,
default="",
help="Base output directory."
)
parser.add_argument(
"--schedule",
type=str,
default="local_run",
help="Schedule to run for this experiment."
)
parser.add_argument(
"--master_grpc_url",
type=str,
default="",
help="URL to master GRPC tensorflow server, e.g.,grpc://127.0.0.1:2222"
)
parser.add_argument(
"--num_parameter_servers",
type=int,
default=0,
help="Number of parameter servers"
)
parser.add_argument(
"--worker_index",
type=int,
default=0,
help="Worker index (>=0)"
)
parser.add_argument(
"--train_steps",
type=int,
default=1000,
help="Number of training steps"
)
parser.add_argument(
"--eval_steps",
type=int,
default=1,
help="Number of evaluation steps"
)
global FLAGS # pylint:disable=global-at-module-level
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
......@@ -33,32 +33,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import sys
from tensorflow.core.protobuf import tensorflow_server_pb2
from tensorflow.python.platform import app
from tensorflow.python.platform import flags
from tensorflow.python.training import server_lib
FLAGS = flags.FLAGS
flags.DEFINE_string("cluster_spec", "", """Cluster spec: SPEC.
SPEC is <JOB>(,<JOB>)*,"
JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*,"
NAME is a valid job name ([a-z][0-9a-z]*),"
HOST is a hostname or IP address,"
PORT is a port number."
E.g., local|localhost:2222;localhost:2223, ps|ps0:2222;ps1:2222""")
flags.DEFINE_string("job_name", "", "Job name: e.g., local")
flags.DEFINE_integer("task_id", 0, "Task index, e.g., 0")
flags.DEFINE_boolean("verbose", False, "Verbose mode")
def parse_cluster_spec(cluster_spec, cluster):
def parse_cluster_spec(cluster_spec, cluster, verbose=False):
"""Parse content of cluster_spec string and inject info into cluster protobuf.
Args:
cluster_spec: cluster specification string, e.g.,
"local|localhost:2222;localhost:2223"
cluster: cluster protobuf.
verbose: If verbose logging is requested.
Raises:
ValueError: if the cluster_spec string is invalid.
......@@ -82,7 +72,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_def.name = job_name
if FLAGS.verbose:
if verbose:
print("Added job named \"%s\"" % job_name)
job_tasks = job_string.split("|")[1].split(";")
......@@ -92,7 +82,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_def.tasks[i] = job_tasks[i]
if FLAGS.verbose:
if verbose:
print(" Added task \"%s\" to job \"%s\"" % (job_tasks[i], job_name))
......@@ -101,7 +91,7 @@ def main(unused_args):
server_def = tensorflow_server_pb2.ServerDef(protocol="grpc")
# Cluster info
parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster)
parse_cluster_spec(FLAGS.cluster_spec, server_def.cluster, FLAGS.verbose)
# Job name
if not FLAGS.job_name:
......@@ -121,4 +111,39 @@ def main(unused_args):
if __name__ == "__main__":
app.run()
parser = argparse.ArgumentParser()
parser.register("type", "bool", lambda v: v.lower() == "true")
parser.add_argument(
"--cluster_spec",
type=str,
default="",
help="""\
Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*," JOB is
<NAME>|<HOST:PORT>(;<HOST:PORT>)*," NAME is a valid job name
([a-z][0-9a-z]*)," HOST is a hostname or IP address," PORT is a
port number." E.g., local|localhost:2222;localhost:2223,
ps|ps0:2222;ps1:2222\
"""
)
parser.add_argument(
"--job_name",
type=str,
default="",
help="Job name: e.g., local"
)
parser.add_argument(
"--task_id",
type=int,
default=0,
help="Task index, e.g., 0"
)
parser.add_argument(
"--verbose",
type="bool",
nargs="?",
const=True,
default=False,
help="Verbose mode"
)
FLAGS, unparsed = parser.parse_known_args()
app.run(main=main, argv=[sys.argv[0]] + unparsed)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册