Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
2bdf0b4e
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
2bdf0b4e
编写于
1月 11, 2017
作者:
V
Vijay Vasudevan
提交者:
TensorFlower Gardener
1月 11, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Convert more flags use to argparse in dist_test
Change: 144278086
上级
e11fae78
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
105 addition
and
42 deletion
+105
-42
tensorflow/tools/dist_test/python/census_widendeep.py
tensorflow/tools/dist_test/python/census_widendeep.py
+61
-23
tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
+44
-19
未找到文件。
tensorflow/tools/dist_test/python/census_widendeep.py
浏览文件 @
2bdf0b4e
...
...
@@ -20,8 +20,10 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
json
import
os
import
sys
from
six.moves
import
urllib
import
tensorflow
as
tf
...
...
@@ -30,28 +32,6 @@ from tensorflow.contrib.learn.python.learn import learn_runner
from
tensorflow.contrib.learn.python.learn.estimators
import
run_config
# Define command-line flags
flags
=
tf
.
app
.
flags
flags
.
DEFINE_string
(
"data_dir"
,
"/tmp/census-data"
,
"Directory for storing the cesnsus data"
)
flags
.
DEFINE_string
(
"model_dir"
,
"/tmp/census_wide_and_deep_model"
,
"Directory for storing the model"
)
flags
.
DEFINE_string
(
"output_dir"
,
""
,
"Base output directory."
)
flags
.
DEFINE_string
(
"schedule"
,
"local_run"
,
"Schedule to run for this experiment."
)
flags
.
DEFINE_string
(
"master_grpc_url"
,
""
,
"URL to master GRPC tensorflow server, e.g.,"
"grpc://127.0.0.1:2222"
)
flags
.
DEFINE_integer
(
"num_parameter_servers"
,
0
,
"Number of parameter servers"
)
flags
.
DEFINE_integer
(
"worker_index"
,
0
,
"Worker index (>=0)"
)
flags
.
DEFINE_integer
(
"train_steps"
,
1000
,
"Number of training steps"
)
flags
.
DEFINE_integer
(
"eval_steps"
,
1
,
"Number of evaluation steps"
)
FLAGS
=
flags
.
FLAGS
# Constants: Data download URLs
TRAIN_DATA_URL
=
"http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.data"
TEST_DATA_URL
=
"http://mlr.cs.umass.edu/ml/machine-learning-databases/adult/adult.test"
...
...
@@ -277,4 +257,62 @@ def main(unused_argv):
if
__name__
==
"__main__"
:
tf
.
app
.
run
()
parser
=
argparse
.
ArgumentParser
()
parser
.
register
(
"type"
,
"bool"
,
lambda
v
:
v
.
lower
()
==
"true"
)
parser
.
add_argument
(
"--data_dir"
,
type
=
str
,
default
=
"/tmp/census-data"
,
help
=
"Directory for storing the cesnsus data"
)
parser
.
add_argument
(
"--model_dir"
,
type
=
str
,
default
=
"/tmp/census_wide_and_deep_model"
,
help
=
"Directory for storing the model"
)
parser
.
add_argument
(
"--output_dir"
,
type
=
str
,
default
=
""
,
help
=
"Base output directory."
)
parser
.
add_argument
(
"--schedule"
,
type
=
str
,
default
=
"local_run"
,
help
=
"Schedule to run for this experiment."
)
parser
.
add_argument
(
"--master_grpc_url"
,
type
=
str
,
default
=
""
,
help
=
"URL to master GRPC tensorflow server, e.g.,grpc://127.0.0.1:2222"
)
parser
.
add_argument
(
"--num_parameter_servers"
,
type
=
int
,
default
=
0
,
help
=
"Number of parameter servers"
)
parser
.
add_argument
(
"--worker_index"
,
type
=
int
,
default
=
0
,
help
=
"Worker index (>=0)"
)
parser
.
add_argument
(
"--train_steps"
,
type
=
int
,
default
=
1000
,
help
=
"Number of training steps"
)
parser
.
add_argument
(
"--eval_steps"
,
type
=
int
,
default
=
1
,
help
=
"Number of evaluation steps"
)
global
FLAGS
# pylint:disable=global-at-module-level
FLAGS
,
unparsed
=
parser
.
parse_known_args
()
tf
.
app
.
run
(
main
=
main
,
argv
=
[
sys
.
argv
[
0
]]
+
unparsed
)
tensorflow/tools/dist_test/server/grpc_tensorflow_server.py
浏览文件 @
2bdf0b4e
...
...
@@ -33,32 +33,22 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
argparse
import
sys
from
tensorflow.core.protobuf
import
tensorflow_server_pb2
from
tensorflow.python.platform
import
app
from
tensorflow.python.platform
import
flags
from
tensorflow.python.training
import
server_lib
FLAGS
=
flags
.
FLAGS
flags
.
DEFINE_string
(
"cluster_spec"
,
""
,
"""Cluster spec: SPEC.
SPEC is <JOB>(,<JOB>)*,"
JOB is <NAME>|<HOST:PORT>(;<HOST:PORT>)*,"
NAME is a valid job name ([a-z][0-9a-z]*),"
HOST is a hostname or IP address,"
PORT is a port number."
E.g., local|localhost:2222;localhost:2223, ps|ps0:2222;ps1:2222"""
)
flags
.
DEFINE_string
(
"job_name"
,
""
,
"Job name: e.g., local"
)
flags
.
DEFINE_integer
(
"task_id"
,
0
,
"Task index, e.g., 0"
)
flags
.
DEFINE_boolean
(
"verbose"
,
False
,
"Verbose mode"
)
def
parse_cluster_spec
(
cluster_spec
,
cluster
):
def
parse_cluster_spec
(
cluster_spec
,
cluster
,
verbose
=
False
):
"""Parse content of cluster_spec string and inject info into cluster protobuf.
Args:
cluster_spec: cluster specification string, e.g.,
"local|localhost:2222;localhost:2223"
cluster: cluster protobuf.
verbose: If verbose logging is requested.
Raises:
ValueError: if the cluster_spec string is invalid.
...
...
@@ -82,7 +72,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_def
.
name
=
job_name
if
FLAGS
.
verbose
:
if
verbose
:
print
(
"Added job named
\"
%s
\"
"
%
job_name
)
job_tasks
=
job_string
.
split
(
"|"
)[
1
].
split
(
";"
)
...
...
@@ -92,7 +82,7 @@ def parse_cluster_spec(cluster_spec, cluster):
job_def
.
tasks
[
i
]
=
job_tasks
[
i
]
if
FLAGS
.
verbose
:
if
verbose
:
print
(
" Added task
\"
%s
\"
to job
\"
%s
\"
"
%
(
job_tasks
[
i
],
job_name
))
...
...
@@ -101,7 +91,7 @@ def main(unused_args):
server_def
=
tensorflow_server_pb2
.
ServerDef
(
protocol
=
"grpc"
)
# Cluster info
parse_cluster_spec
(
FLAGS
.
cluster_spec
,
server_def
.
cluster
)
parse_cluster_spec
(
FLAGS
.
cluster_spec
,
server_def
.
cluster
,
FLAGS
.
verbose
)
# Job name
if
not
FLAGS
.
job_name
:
...
...
@@ -121,4 +111,39 @@ def main(unused_args):
if
__name__
==
"__main__"
:
app
.
run
()
parser
=
argparse
.
ArgumentParser
()
parser
.
register
(
"type"
,
"bool"
,
lambda
v
:
v
.
lower
()
==
"true"
)
parser
.
add_argument
(
"--cluster_spec"
,
type
=
str
,
default
=
""
,
help
=
"""
\
Cluster spec: SPEC. SPEC is <JOB>(,<JOB>)*," JOB is
<NAME>|<HOST:PORT>(;<HOST:PORT>)*," NAME is a valid job name
([a-z][0-9a-z]*)," HOST is a hostname or IP address," PORT is a
port number." E.g., local|localhost:2222;localhost:2223,
ps|ps0:2222;ps1:2222
\
"""
)
parser
.
add_argument
(
"--job_name"
,
type
=
str
,
default
=
""
,
help
=
"Job name: e.g., local"
)
parser
.
add_argument
(
"--task_id"
,
type
=
int
,
default
=
0
,
help
=
"Task index, e.g., 0"
)
parser
.
add_argument
(
"--verbose"
,
type
=
"bool"
,
nargs
=
"?"
,
const
=
True
,
default
=
False
,
help
=
"Verbose mode"
)
FLAGS
,
unparsed
=
parser
.
parse_known_args
()
app
.
run
(
main
=
main
,
argv
=
[
sys
.
argv
[
0
]]
+
unparsed
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录