Commit 14d492d7 authored by Bruce Fontaine, committed by TensorFlower Gardener

Add alternative in high level TPU embedding API to not use feature columns, but to use the mid level API FeatureConfig and TableConfig instead.

PiperOrigin-RevId: 251656574
Parent c6244844
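
For context, the sketch below illustrates the mid level alternative the commit message refers to: embedding tables and features are described directly with `TableConfig` and `FeatureConfig` instead of through `tf.feature_column`. The import path and any keyword names not visible in this diff are assumptions for illustration, not taken verbatim from the commit.

    # Hypothetical sketch of the mid level configuration path; names other than
    # those visible in this diff are assumptions.
    from tensorflow.python.tpu import tpu_embedding

    # Describe one embedding table: a 1000-row vocabulary of 16-dimensional vectors.
    table_to_config = {
        'video_table': tpu_embedding.TableConfig(
            vocabulary_size=1000,
            dimension=16,
            combiner='mean'),
    }

    # Describe one feature that looks its ids up in that table.
    feature_to_config = {
        'watched_videos': tpu_embedding.FeatureConfig(table_id='video_table'),
    }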
......@@ -50,8 +50,7 @@ def embedding_column(categorical_column,
dimension,
combiner='mean',
initializer=None,
max_sequence_length=0,
partition_strategy='div'):
max_sequence_length=0):
"""TPU embedding_column for `tf.feature_column.embedding_column`.
Note that the interface for TPU embedding_column is different from the non-TPU
......@@ -78,11 +77,6 @@ def embedding_column(categorical_column,
length. Any sequence shorter than this will be padded with 0 embeddings
and any sequence longer will be truncated. This must be positive for
sequence features and 0 for non-sequence features.
partition_strategy: Determines how tensors are sharded on the tpu hosts. See
`tf.nn.safe_embedding_lookup_sparse` for more details. Allowed value are
`"div"` and `"mod"'. If `"mod"` is used, evaluation and exporting the
model to CPU will not work. In order to do this, you must shuffle the
embedding tensors into a single shard.
Returns:
A _TPUEmbeddingColumn.
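
A minimal usage sketch of the signature after this change (no `partition_strategy` argument). The categorical column and the import path of the TPU feature column module are illustrative assumptions:

    import tensorflow as tf
    from tensorflow.python.tpu import feature_column as tpu_fc  # assumed path of the module edited here

    watched = tf.feature_column.categorical_column_with_identity(
        'watched_videos', num_buckets=1000)

    # The partition_strategy argument removed by this change is no longer passed.
    watched_embedding = tpu_fc.embedding_column(
        watched,
        dimension=16,
        combiner='mean',
        max_sequence_length=0)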
......@@ -128,8 +122,7 @@ def embedding_column(categorical_column,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
max_sequence_length=max_sequence_length)
# For Embedding column, the initializer is hidden inside the creator Fn, which
# is not accessible later. So, we attach it to a special field. Also note
# that non-TPU Embedding column and non-TPU shared Embedding column handle the
......@@ -143,8 +136,7 @@ def shared_embedding_columns(categorical_columns,
combiner='mean',
initializer=None,
shared_embedding_collection_name=None,
max_sequence_lengths=None,
partition_strategy='div'):
max_sequence_lengths=None):
"""List of dense columns that convert from sparse, categorical input.
Note that the interface for TPU embedding_column is different from the non-TPU
......@@ -177,9 +169,6 @@ def shared_embedding_columns(categorical_columns,
to sequence columns specify the max sequence length for the column. Any
sequence shorter than this will be padded with 0 embeddings and any
sequence longer will be truncated.
partition_strategy: Determines how tensors are sharded on the tpu hosts. See
`tf.nn.safe_embedding_lookup_sparse` for more details. Allowed value are
`"div"` and `"mod"'.
Returns:
A _TPUEmbeddingColumn.
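
Similarly, a hedged sketch of shared embedding columns after this change; feature names, vocabulary sizes, and `dimension` are made up for illustration:

    import tensorflow as tf
    from tensorflow.python.tpu import feature_column as tpu_fc  # assumed path

    # Two sparse features that share a single embedding table.
    query_tokens = tf.feature_column.categorical_column_with_identity(
        'query_tokens', num_buckets=50000)
    doc_tokens = tf.feature_column.categorical_column_with_identity(
        'doc_tokens', num_buckets=50000)

    shared = tpu_fc.shared_embedding_columns(
        [query_tokens, doc_tokens],
        dimension=32,
        combiner='mean',
        # One entry per column; both are non-sequence features here.
        max_sequence_lengths=[0, 0])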
......@@ -249,8 +238,7 @@ def shared_embedding_columns(categorical_columns,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
max_sequence_length=max_sequence_length)
tpu_columns.append(column)
return tpu_columns
......@@ -259,8 +247,7 @@ def shared_embedding_columns(categorical_columns,
class _TPUBaseEmbeddingColumn(object):
"""Base class for TPU Embedding Column."""
def __init__(self, categorical_column, max_sequence_length=0,
partition_strategy='div'):
def __init__(self, categorical_column, max_sequence_length=0):
self._tpu_categorical_column = categorical_column
self._max_sequence_length = max_sequence_length
if (self.is_sequence_column() and max_sequence_length < 1):
......@@ -272,10 +259,6 @@ class _TPUBaseEmbeddingColumn(object):
raise ValueError('Non zero max_seq_length={} specified for non '
'sequence column {}.'.format(max_sequence_length,
categorical_column.name))
self._partition_strategy = partition_strategy
if partition_strategy not in ('mod', 'div'):
raise ValueError('partition_strategy must be one of `mod` or `div`. '
'Received {}.'.format(partition_strategy))
def get_combiner(self):
"""Returns the embedding combiner."""
......@@ -320,9 +303,6 @@ class _TPUBaseEmbeddingColumn(object):
return get_sequence_length_feature_key_name_from_feature_key_name(
self.get_feature_key_name())
def get_partition_strategy(self):
return self._partition_strategy
class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
"""Core Embedding Column."""
......@@ -336,8 +316,7 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0,
partition_strategy='div'):
max_sequence_length=0):
# Note, args ckpt_to_load_from, tensor_name_in_ckpt, max_norm and trainable
# are not supported on TPU. They are solely for matching the signature of
# __new__ of parent class fc._EmbeddingColumn.
......@@ -361,11 +340,9 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0,
partition_strategy='div'):
max_sequence_length=0):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column,
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
max_sequence_length=max_sequence_length)
self._key = None
def get_combiner(self):
......@@ -406,18 +383,12 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._EmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
......@@ -434,18 +405,12 @@ class _TPUEmbeddingColumn(_TPUBaseEmbeddingColumn, fc._EmbeddingColumn):
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._EmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._EmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
......@@ -478,8 +443,7 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0,
partition_strategy='div'):
max_sequence_length=0):
return fc._SharedEmbeddingColumn.__new__(
cls,
categorical_column,
......@@ -502,12 +466,10 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
tensor_name_in_ckpt=None,
max_norm=None,
trainable=True,
max_sequence_length=0,
partition_strategy='div'):
max_sequence_length=0):
_TPUBaseEmbeddingColumn.__init__(self, categorical_column,
max_sequence_length=max_sequence_length,
partition_strategy=partition_strategy)
max_sequence_length=max_sequence_length)
self._key = None
def get_combiner(self):
......@@ -548,18 +510,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._SharedEmbeddingColumn._get_dense_tensor(
self, inputs, weight_collections, trainable)
......@@ -577,18 +533,12 @@ class _TPUSharedEmbeddingColumn(_TPUBaseEmbeddingColumn,
def _get_sequence_dense_tensor(
self, inputs, weight_collections=None, trainable=None):
if tpu.under_tpu_inference_context():
if self._partition_strategy == 'mod':
raise NotImplementedError('Export saved model does not support MOD '
'sharded embeddings.')
def host_computation():
return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
return tpu.outside_compilation(host_computation)
if _is_running_on_cpu():
if self._partition_strategy == 'mod':
raise NotImplementedError('TPUEmbedding on CPU does not support MOD '
'sharded embeddings.')
return fc._SharedEmbeddingColumn._get_sequence_dense_tensor(
self, inputs, weight_collections, trainable)
......
......@@ -100,12 +100,13 @@ class TableConfig(
class FeatureConfig(
collections.namedtuple(
'FeatureConfig',
['table_id', 'max_sequence_length'])):
['table_id', 'max_sequence_length', 'weight_key'])):
"""Feature configuration."""
def __new__(cls,
table_id,
max_sequence_length=0):
max_sequence_length=0,
weight_key=None):
"""Feature configuration.
Args:
......@@ -114,6 +115,8 @@ class FeatureConfig(
the corresponding maximum sequence length. If the sequence is longer
than this, it will be truncated. If 0, the feature is not a sequence
feature.
weight_key: If using weights for the combiner, this key specifies which
input feature contains the weights.
Returns:
`FeatureConfig`.
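
To make the new `weight_key` concrete, here is a small hedged sketch: the input dictionary carries a parallel weights feature, and the `FeatureConfig` points at it so the combiner uses those weights. The feature and table names are hypothetical.

    from tensorflow.python.tpu import tpu_embedding

    # Ids for the 'watched_videos' feature are weighted by a companion feature.
    feature_to_config = {
        'watched_videos': tpu_embedding.FeatureConfig(
            table_id='video_table',
            max_sequence_length=0,
            # Name of the input feature that holds the per-id combiner weights.
            weight_key='watched_videos_weights'),
    }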
......@@ -125,7 +128,8 @@ class FeatureConfig(
raise ValueError('Invalid max_sequence_length {}.'.format(
max_sequence_length))
return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length)
return super(FeatureConfig, cls).__new__(cls, table_id, max_sequence_length,
weight_key)
class EnqueueData(
......