Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
848e56c9
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
848e56c9
编写于
2月 22, 2019
作者:
B
Bruce Fontaine
提交者:
TensorFlower Gardener
2月 22, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Restrict tpu_embedding hosts to just worker jobs. Pass ClusterDef to TpuEmbedding.
PiperOrigin-RevId: 235210894
上级
85538ad4
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
68 addition
and
28 deletion
+68
-28
tensorflow/python/tpu/_tpu_estimator_embedding.py
tensorflow/python/tpu/_tpu_estimator_embedding.py
+9
-3
tensorflow/python/tpu/tpu_context.py
tensorflow/python/tpu/tpu_context.py
+4
-21
tensorflow/python/tpu/tpu_embedding.py
tensorflow/python/tpu/tpu_embedding.py
+12
-4
tensorflow/python/tpu/tpu_system_metadata.py
tensorflow/python/tpu/tpu_system_metadata.py
+43
-0
未找到文件。
tensorflow/python/tpu/_tpu_estimator_embedding.py
浏览文件 @
848e56c9
...
...
@@ -272,13 +272,13 @@ class EmbeddingConfig(object):
"""
def
__init__
(
self
,
embedding_config_spec
,
train_batch_size
,
eval_batch_size
,
num_hosts
,
num_cores
,
master
):
num_hosts
,
num_cores
,
run_config
):
self
.
_embedding_config_spec
=
embedding_config_spec
self
.
_train_batch_size
=
train_batch_size
self
.
_eval_batch_size
=
eval_batch_size
self
.
_num_hosts
=
num_hosts
self
.
_num_cores
=
num_cores
self
.
_
master
=
master
self
.
_
run_config
=
run_config
self
.
_table_to_config_dict
,
self
.
_feature_to_table_dict
=
(
get_tpu_embedding_config_from_feature_columns
(
...
...
@@ -306,13 +306,19 @@ class EmbeddingConfig(object):
else
:
raise
ValueError
(
'Mode {} is not supported.'
.
format
(
mode
))
master
=
(
self
.
_run_config
.
evaluation_master
if
mode
==
model_fn_lib
.
ModeKeys
.
EVAL
else
self
.
_run_config
.
master
)
cluster_def
=
(
self
.
_run_config
.
session_config
.
cluster_def
if
self
.
_run_config
.
session_config
else
None
)
tpu_embedding_
=
tpu_embedding
.
TPUEmbedding
(
self
.
_table_to_config_dict
,
self
.
_feature_to_table_dict
,
batch_size
,
tpu_embedding_mode
,
self
.
_
master
,
master
,
self
.
_optimization_parameters
,
cluster_def
,
)
return
tpu_embedding_
...
...
tensorflow/python/tpu/tpu_context.py
浏览文件 @
848e56c9
...
...
@@ -313,7 +313,7 @@ class _InternalTPUContext(object):
if
self
.
_use_tpu
and
self
.
_embedding_config_spec
:
embedding_config
=
_tpu_estimator_embedding
.
EmbeddingConfig
(
self
.
_embedding_config_spec
,
self
.
_train_batch_size
,
self
.
_eval_batch_size
,
self
.
num_hosts
,
self
.
num_cores
,
master
)
self
.
_eval_batch_size
,
self
.
num_hosts
,
self
.
num_cores
,
self
.
config
)
if
not
embedding_config
.
has_embedding_tables
():
embedding_config
=
None
self
.
_lazy_embedding_config_dict
[
master
]
=
embedding_config
...
...
@@ -510,27 +510,10 @@ class _InternalTPUContext(object):
master
=
(
run_config
.
evaluation_master
if
mode
==
model_fn_lib
.
ModeKeys
.
EVAL
else
run_config
.
master
)
if
master
in
_LOCAL_MASTERS
:
return
None
cluster_def
=
(
run_config
.
session_config
.
cluster_def
if
run_config
.
session_config
else
None
)
if
(
not
run_config
.
session_config
or
not
run_config
.
session_config
.
cluster_def
.
job
):
return
_DEFAULT_JOB_NAME
cluster_def
=
run_config
.
session_config
.
cluster_def
job_names
=
set
([
job
.
name
for
job
in
cluster_def
.
job
])
if
_DEFAULT_JOB_NAME
in
job_names
:
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
raise
ValueError
(
'Currently, tpu_worker is not an allowed job name.'
)
if
len
(
job_names
)
==
1
:
return
cluster_def
.
job
[
0
].
name
if
len
(
job_names
)
==
2
:
if
_DEFAULT_COORDINATOR_JOB_NAME
in
job_names
:
job_names
.
remove
(
_DEFAULT_COORDINATOR_JOB_NAME
)
return
job_names
.
pop
()
# TODO(b/67716447): Include more sophisticated heuristics.
raise
ValueError
(
'Could not infer TPU job name. Please specify a tpu_job_name as part '
'of your TPUConfig.'
)
return
tpu_system_metadata_lib
.
master_job
(
master
,
cluster_def
)
@
property
def
tpu_host_placement_function
(
self
):
...
...
tensorflow/python/tpu/tpu_embedding.py
浏览文件 @
848e56c9
...
...
@@ -308,7 +308,8 @@ class TPUEmbedding(object):
batch_size
,
mode
,
master
,
optimization_parameters
=
None
):
optimization_parameters
=
None
,
cluster_def
=
None
):
"""API for using TPU for embedding lookups.
Args:
...
...
@@ -324,6 +325,7 @@ class TPUEmbedding(object):
optimization_parameters: `AdagradParameters`, `AdamParameters`,
`Stochasticgradientdescentparameters`. Must be set in training and must
be `None` in inference.
cluster_def: A ClusterDef object describing the TPU cluster.
Raises:
ValueError: if any input is invalid.
...
...
@@ -341,14 +343,20 @@ class TPUEmbedding(object):
self
.
_batch_size
=
batch_size
self
.
_master
=
master
self
.
_cluster_def
=
cluster_def
self
.
_tpu_system_metadata
=
(
tpu_system_metadata_lib
.
_query_tpu_system_metadata
(
self
.
_master
))
# pylint: disable=protected-access
tpu_system_metadata_lib
.
_query_tpu_system_metadata
(
# pylint: disable=protected-access
self
.
_master
,
cluster_def
=
self
.
_cluster_def
))
if
self
.
_tpu_system_metadata
.
num_cores
==
0
:
raise
ValueError
(
'TPUEmbedding needs TPUs, but master {} does not have '
'TPUs.'
.
format
(
self
.
_master
))
self
.
_num_hosts
=
self
.
_tpu_system_metadata
.
num_hosts
self
.
_hosts
=
[
device
.
name
for
device
in
self
.
_tpu_system_metadata
.
devices
if
'device:CPU:'
in
device
.
name
]
master_job_name
=
tpu_system_metadata_lib
.
master_job
(
self
.
_master
,
self
.
_cluster_def
)
self
.
_hosts
=
sorted
([
device
.
name
for
device
in
self
.
_tpu_system_metadata
.
devices
if
'device:CPU:'
in
device
.
name
and
(
master_job_name
is
None
or
master_job_name
in
device
.
name
)])
self
.
_num_cores_per_host
=
self
.
_tpu_system_metadata
.
num_of_cores_per_host
self
.
_num_cores
=
self
.
_tpu_system_metadata
.
num_cores
...
...
tensorflow/python/tpu/tpu_system_metadata.py
浏览文件 @
848e56c9
...
...
@@ -34,6 +34,10 @@ _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000 # 5 mins
_TPU_DEVICE_REG
=
re
.
compile
(
r
'.*task:(\d+)/.*device:TPU:(\d+)$'
)
_DEFAULT_JOB_NAME
=
'tpu_worker'
_DEFAULT_COORDINATOR_JOB_NAME
=
'coordinator'
_LOCAL_MASTERS
=
(
''
,
'local'
)
# _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
# including num_cores and num_hosts.
_TPUSystemMetadata
=
collections
.
namedtuple
(
'_TPUSystemMetadata'
,
[
...
...
@@ -154,3 +158,42 @@ def get_session_config_with_timeout(timeout_in_secs, cluster_def):
config
=
config_pb2
.
ConfigProto
(
operation_timeout_in_ms
=
timeout_in_secs
,
cluster_def
=
cluster_def
)
return
config
def
master_job
(
master
,
cluster_def
):
"""Returns the canonnical job name to use to place TPU computations on.
Args:
master: A `string` representing the TensorFlow master to use.
cluster_def: A ClusterDef object describing the TPU cluster.
Returns:
A string containing the job name, or None if no job should be specified.
Raises:
ValueError: If the user needs to specify a tpu_job_name, because we are
unable to infer the job name automatically, or if the user-specified job
names are inappropriate.
"""
# If the user specifies the tpu_job_name, use that.
if
master
in
_LOCAL_MASTERS
:
return
None
if
(
not
cluster_def
or
not
cluster_def
.
job
):
return
_DEFAULT_JOB_NAME
job_names
=
set
([
job
.
name
for
job
in
cluster_def
.
job
])
if
_DEFAULT_JOB_NAME
in
job_names
:
# b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
raise
ValueError
(
'Currently, tpu_worker is not an allowed job name.'
)
if
len
(
job_names
)
==
1
:
return
cluster_def
.
job
[
0
].
name
if
len
(
job_names
)
==
2
:
if
_DEFAULT_COORDINATOR_JOB_NAME
in
job_names
:
job_names
.
remove
(
_DEFAULT_COORDINATOR_JOB_NAME
)
return
job_names
.
pop
()
# TODO(b/67716447): Include more sophisticated heuristics.
raise
ValueError
(
'Could not infer TPU job name. Please specify a tpu_job_name as part '
'of your TPUConfig.'
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录