Commit 848e56c9 authored by Bruce Fontaine, committed by TensorFlower Gardener

Restrict tpu_embedding hosts to just worker jobs. Pass ClusterDef to TpuEmbedding.

PiperOrigin-RevId: 235210894
Parent 85538ad4
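To make the intent concrete, here is a hedged sketch of how a ClusterDef can be attached to the RunConfig that EmbeddingConfig now receives; the host names, ports, and the use of a plain tf.estimator.RunConfig (rather than the TPU-specific RunConfig) are illustrative assumptions, not part of this commit.

import tensorflow as tf

# Describe the cluster; the worker job hosts the TPUs, the coordinator does not.
cluster_spec = tf.train.ClusterSpec({
    'worker': ['worker0.example.com:8470'],
    'coordinator': ['coordinator.example.com:2222'],
})
# Embed the ClusterDef in the session ConfigProto carried by the RunConfig.
session_config = tf.ConfigProto(cluster_def=cluster_spec.as_cluster_def())
run_config = tf.estimator.RunConfig(session_config=session_config)
# After this change, EmbeddingConfig reads run_config.session_config.cluster_def
# and forwards it to tpu_embedding.TPUEmbedding.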
@@ -272,13 +272,13 @@ class EmbeddingConfig(object):
   """

   def __init__(self, embedding_config_spec, train_batch_size, eval_batch_size,
-               num_hosts, num_cores, master):
+               num_hosts, num_cores, run_config):
     self._embedding_config_spec = embedding_config_spec
     self._train_batch_size = train_batch_size
     self._eval_batch_size = eval_batch_size
     self._num_hosts = num_hosts
     self._num_cores = num_cores
-    self._master = master
+    self._run_config = run_config

     self._table_to_config_dict, self._feature_to_table_dict = (
         get_tpu_embedding_config_from_feature_columns(
@@ -306,13 +306,19 @@ class EmbeddingConfig(object):
     else:
       raise ValueError('Mode {} is not supported.'.format(mode))

+    master = (
+        self._run_config.evaluation_master
+        if mode == model_fn_lib.ModeKeys.EVAL else self._run_config.master)
+    cluster_def = (self._run_config.session_config.cluster_def
+                   if self._run_config.session_config else None)
     tpu_embedding_ = tpu_embedding.TPUEmbedding(
         self._table_to_config_dict,
         self._feature_to_table_dict,
         batch_size,
         tpu_embedding_mode,
-        self._master,
+        master,
         self._optimization_parameters,
+        cluster_def,
     )
     return tpu_embedding_
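For reference, a minimal sketch of the selection logic added above, assuming the run_config built in the earlier snippet and using a plain string in place of model_fn_lib.ModeKeys:

mode = 'eval'  # stand-in for model_fn_lib.ModeKeys.EVAL
master = (run_config.evaluation_master
          if mode == 'eval' else run_config.master)
cluster_def = (run_config.session_config.cluster_def
               if run_config.session_config else None)
# Both values are then passed to tpu_embedding.TPUEmbedding along with the
# table/feature configs and optimization parameters.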
@@ -313,7 +313,7 @@ class _InternalTPUContext(object):
       if self._use_tpu and self._embedding_config_spec:
         embedding_config = _tpu_estimator_embedding.EmbeddingConfig(
             self._embedding_config_spec, self._train_batch_size,
-            self._eval_batch_size, self.num_hosts, self.num_cores, master)
+            self._eval_batch_size, self.num_hosts, self.num_cores, self.config)
         if not embedding_config.has_embedding_tables():
           embedding_config = None
       self._lazy_embedding_config_dict[master] = embedding_config
@@ -510,27 +510,10 @@ class _InternalTPUContext(object):
     master = (
         run_config.evaluation_master
         if mode == model_fn_lib.ModeKeys.EVAL else run_config.master)
-    if master in _LOCAL_MASTERS:
-      return None
-
-    if (not run_config.session_config or
-        not run_config.session_config.cluster_def.job):
-      return _DEFAULT_JOB_NAME
-    cluster_def = run_config.session_config.cluster_def
-    job_names = set([job.name for job in cluster_def.job])
-    if _DEFAULT_JOB_NAME in job_names:
-      # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
-      raise ValueError('Currently, tpu_worker is not an allowed job name.')
-    if len(job_names) == 1:
-      return cluster_def.job[0].name
-    if len(job_names) == 2:
-      if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
-        job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
-        return job_names.pop()
-    # TODO(b/67716447): Include more sophisticated heuristics.
-    raise ValueError(
-        'Could not infer TPU job name. Please specify a tpu_job_name as part '
-        'of your TPUConfig.')
+    cluster_def = (run_config.session_config.cluster_def
+                   if run_config.session_config else None)
+    return tpu_system_metadata_lib.master_job(master, cluster_def)

   @property
   def tpu_host_placement_function(self):
@@ -308,7 +308,8 @@ class TPUEmbedding(object):
                batch_size,
                mode,
                master,
-               optimization_parameters=None):
+               optimization_parameters=None,
+               cluster_def=None):
     """API for using TPU for embedding lookups.

     Args:
@@ -324,6 +325,7 @@ class TPUEmbedding(object):
       optimization_parameters: `AdagradParameters`, `AdamParameters`,
         `StochasticGradientDescentParameters`. Must be set in training and must
         be `None` in inference.
+      cluster_def: A ClusterDef object describing the TPU cluster.

     Raises:
       ValueError: if any input is invalid.
@@ -341,14 +343,20 @@ class TPUEmbedding(object):
     self._batch_size = batch_size

     self._master = master
+    self._cluster_def = cluster_def
     self._tpu_system_metadata = (
-        tpu_system_metadata_lib._query_tpu_system_metadata(self._master))  # pylint: disable=protected-access
+        tpu_system_metadata_lib._query_tpu_system_metadata(  # pylint: disable=protected-access
+            self._master, cluster_def=self._cluster_def))
     if self._tpu_system_metadata.num_cores == 0:
       raise ValueError('TPUEmbedding needs TPUs, but master {} does not have '
                        'TPUs.'.format(self._master))

     self._num_hosts = self._tpu_system_metadata.num_hosts
-    self._hosts = [device.name for device in self._tpu_system_metadata.devices
-                   if 'device:CPU:' in device.name]
+    master_job_name = tpu_system_metadata_lib.master_job(self._master,
+                                                         self._cluster_def)
+    self._hosts = sorted([
+        device.name for device in self._tpu_system_metadata.devices
+        if 'device:CPU:' in device.name and (master_job_name is None or
+                                             master_job_name in device.name)])
     self._num_cores_per_host = self._tpu_system_metadata.num_of_cores_per_host
     self._num_cores = self._tpu_system_metadata.num_cores
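The effect of the new host filtering can be seen in isolation with the sketch below; the device names are hypothetical, and master_job_name stands in for the result of tpu_system_metadata_lib.master_job.

devices = [
    '/job:worker/replica:0/task:0/device:CPU:0',
    '/job:worker/replica:0/task:1/device:CPU:0',
    '/job:coordinator/replica:0/task:0/device:CPU:0',
]
master_job_name = 'worker'
hosts = sorted([
    name for name in devices
    if 'device:CPU:' in name and (master_job_name is None or
                                  master_job_name in name)])
# hosts == ['/job:worker/replica:0/task:0/device:CPU:0',
#           '/job:worker/replica:0/task:1/device:CPU:0']
# The coordinator's CPU device is no longer treated as an embedding host.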
@@ -34,6 +34,10 @@
 _INITIAL_TPU_SYSTEM_TIMEOUT_IN_MS = 300 * 1000  # 5 mins
 _TPU_DEVICE_REG = re.compile(r'.*task:(\d+)/.*device:TPU:(\d+)$')

+_DEFAULT_JOB_NAME = 'tpu_worker'
+_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator'
+_LOCAL_MASTERS = ('', 'local')
+
 # _TPUSystemMetadata is used by TPUEstimator to hold TPU configuration,
 # including num_cores and num_hosts.
 _TPUSystemMetadata = collections.namedtuple('_TPUSystemMetadata', [
@@ -154,3 +158,42 @@ def get_session_config_with_timeout(timeout_in_secs, cluster_def):
   config = config_pb2.ConfigProto(
       operation_timeout_in_ms=timeout_in_secs, cluster_def=cluster_def)
   return config
+
+
+def master_job(master, cluster_def):
+  """Returns the canonical job name to use to place TPU computations on.
+
+  Args:
+    master: A `string` representing the TensorFlow master to use.
+    cluster_def: A ClusterDef object describing the TPU cluster.
+
+  Returns:
+    A string containing the job name, or None if no job should be specified.
+
+  Raises:
+    ValueError: If the user needs to specify a tpu_job_name, because we are
+      unable to infer the job name automatically, or if the user-specified job
+      names are inappropriate.
+  """
+  # If the user specifies the tpu_job_name, use that.
+
+  if master in _LOCAL_MASTERS:
+    return None
+
+  if (not cluster_def or not cluster_def.job):
+    return _DEFAULT_JOB_NAME
+
+  job_names = set([job.name for job in cluster_def.job])
+  if _DEFAULT_JOB_NAME in job_names:
+    # b/37868888 tracks allowing ClusterSpec propagation to reuse job names.
+    raise ValueError('Currently, tpu_worker is not an allowed job name.')
+  if len(job_names) == 1:
+    return cluster_def.job[0].name
+  if len(job_names) == 2:
+    if _DEFAULT_COORDINATOR_JOB_NAME in job_names:
+      job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME)
+      return job_names.pop()
+  # TODO(b/67716447): Include more sophisticated heuristics.
+  raise ValueError(
+      'Could not infer TPU job name. Please specify a tpu_job_name as part '
+      'of your TPUConfig.')
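A short usage sketch of master_job follows, assuming the cluster_pb2 proto from TensorFlow core; the addresses are made up for illustration.

from tensorflow.core.protobuf import cluster_pb2

cluster_def = cluster_pb2.ClusterDef()
worker = cluster_def.job.add()
worker.name = 'worker'
worker.tasks[0] = 'worker0.example.com:8470'
coordinator = cluster_def.job.add()
coordinator.name = 'coordinator'
coordinator.tasks[0] = 'coordinator.example.com:2222'

master_job('grpc://worker0.example.com:8470', cluster_def)  # -> 'worker'
master_job('', None)           # -> None; '' and 'local' are _LOCAL_MASTERS
master_job('grpc://10.0.0.1:8470', None)  # -> 'tpu_worker' (_DEFAULT_JOB_NAME)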