提交 199fe84d 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Update sequence categorical columns to new FeatureColumn API.

PiperOrigin-RevId: 225366627
上级 c5002272
......@@ -110,8 +110,8 @@ py_test(
"//tensorflow/python:parsing_ops",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:training",
"//tensorflow/python/feature_column",
"//tensorflow/python/feature_column:feature_column_py",
"//tensorflow/python/feature_column:feature_column_v2_test",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
......
......@@ -203,7 +203,8 @@ def sequence_categorical_column_with_identity(
columns = [watches_embedding]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
input_layer, sequence_length = sequence_input_layer(features, columns)
sequence_feature_layer = SequenceFeatureLayer(columns)
input_layer, sequence_length = sequence_feature_layer(features)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
outputs, state = tf.nn.dynamic_rnn(
......@@ -219,15 +220,17 @@ def sequence_categorical_column_with_identity(
`[0, num_buckets)`, and will replace out-of-range inputs.
Returns:
A `_SequenceCategoricalColumn`.
A `SequenceCategoricalColumn`.
Raises:
ValueError: if `num_buckets` is less than one.
ValueError: if `default_value` is not in range `[0, num_buckets)`.
"""
return fc_old._SequenceCategoricalColumn(
fc_old._categorical_column_with_identity(
key=key, num_buckets=num_buckets, default_value=default_value))
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_identity(
key=key,
num_buckets=num_buckets,
default_value=default_value))
def sequence_categorical_column_with_hash_bucket(
......@@ -247,7 +250,8 @@ def sequence_categorical_column_with_hash_bucket(
columns = [tokens_embedding]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
input_layer, sequence_length = sequence_input_layer(features, columns)
sequence_feature_layer = SequenceFeatureLayer(columns)
input_layer, sequence_length = sequence_feature_layer(features)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
outputs, state = tf.nn.dynamic_rnn(
......@@ -260,15 +264,17 @@ def sequence_categorical_column_with_hash_bucket(
dtype: The type of features. Only string and integer types are supported.
Returns:
A `_SequenceCategoricalColumn`.
A `SequenceCategoricalColumn`.
Raises:
ValueError: `hash_bucket_size` is not greater than 1.
ValueError: `dtype` is neither string nor integer.
"""
return fc_old._SequenceCategoricalColumn(
fc_old._categorical_column_with_hash_bucket(
key=key, hash_bucket_size=hash_bucket_size, dtype=dtype))
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_hash_bucket(
key=key,
hash_bucket_size=hash_bucket_size,
dtype=dtype))
def sequence_categorical_column_with_vocabulary_file(
......@@ -290,7 +296,8 @@ def sequence_categorical_column_with_vocabulary_file(
columns = [states_embedding]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
input_layer, sequence_length = sequence_input_layer(features, columns)
sequence_feature_layer = SequenceFeatureLayer(columns)
input_layer, sequence_length = sequence_feature_layer(features)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
outputs, state = tf.nn.dynamic_rnn(
......@@ -314,7 +321,7 @@ def sequence_categorical_column_with_vocabulary_file(
dtype: The type of features. Only string and integer types are supported.
Returns:
A `_SequenceCategoricalColumn`.
A `SequenceCategoricalColumn`.
Raises:
ValueError: `vocabulary_file` is missing or cannot be opened.
......@@ -323,8 +330,8 @@ def sequence_categorical_column_with_vocabulary_file(
ValueError: `num_oov_buckets` and `default_value` are both specified.
ValueError: `dtype` is neither string nor integer.
"""
return fc_old._SequenceCategoricalColumn(
fc_old._categorical_column_with_vocabulary_file(
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_file(
key=key,
vocabulary_file=vocabulary_file,
vocabulary_size=vocabulary_size,
......@@ -351,7 +358,8 @@ def sequence_categorical_column_with_vocabulary_list(
columns = [colors_embedding]
features = tf.parse_example(..., features=make_parse_example_spec(columns))
input_layer, sequence_length = sequence_input_layer(features, columns)
sequence_feature_layer = SequenceFeatureLayer(columns)
input_layer, sequence_length = sequence_feature_layer(features)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(hidden_size)
outputs, state = tf.nn.dynamic_rnn(
......@@ -375,7 +383,7 @@ def sequence_categorical_column_with_vocabulary_list(
with `default_value`.
Returns:
A `_SequenceCategoricalColumn`.
A `SequenceCategoricalColumn`.
Raises:
ValueError: if `vocabulary_list` is empty, or contains duplicate keys.
......@@ -383,8 +391,8 @@ def sequence_categorical_column_with_vocabulary_list(
ValueError: `num_oov_buckets` and `default_value` are both specified.
ValueError: if `dtype` is not integer or string.
"""
return fc_old._SequenceCategoricalColumn(
fc_old._categorical_column_with_vocabulary_list(
return fc.SequenceCategoricalColumn(
fc.categorical_column_with_vocabulary_list(
key=key,
vocabulary_list=vocabulary_list,
dtype=dtype,
......
......@@ -26,7 +26,7 @@ from tensorflow.contrib.feature_column.python.feature_column import sequence_fea
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column_v2 as sfc
from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_lib as fc
from tensorflow.python.feature_column.feature_column import _LazyBuilder
from tensorflow.python.feature_column.feature_column_v2_test import _TestStateManager
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
......@@ -131,7 +131,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase):
feature_columns=[embedding_column_b, embedding_column_a])
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
self.assertCountEqual(
('sequence_input_layer/aaa_embedding/embedding_weights:0',
'sequence_input_layer/bbb_embedding/embedding_weights:0'),
tuple([v.name for v in global_vars]))
......@@ -223,7 +223,7 @@ class SequenceInputLayerTest(test.TestCase, parameterized.TestCase):
feature_columns=shared_embedding_columns)
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
self.assertCountEqual(
('sequence_input_layer/aaa_bbb_shared_embedding/embedding_weights:0',),
tuple([v.name for v in global_vars]))
with monitored_session.MonitoredSession() as sess:
......@@ -670,6 +670,23 @@ def _assert_sparse_tensor_indices_shape(test_case, expected, actual):
test_case.assertAllEqual(expected.dense_shape, actual.dense_shape)
def _get_sequence_dense_tensor(column, features):
  """Resolves `column` against `features` without a state manager.

  Args:
    column: A sequence dense column (e.g. indicator/shared-embedding column).
    features: A dict mapping feature keys to (sparse) input tensors.

  Returns:
    The `(dense_tensor, sequence_length)` pair from
    `column.get_sequence_dense_tensor`.
  """
  cache = fc.FeatureTransformationCache(features)
  return column.get_sequence_dense_tensor(cache, None)
def _get_sequence_dense_tensor_state(column, features):
  """Resolves `column` against `features` using a fresh test state manager.

  Creates a `_TestStateManager`, lets `column` register its state (e.g.
  embedding variables) on it, then performs the sequence dense lookup.

  Args:
    column: A sequence dense column that requires state (e.g. embedding
      column).
    features: A dict mapping feature keys to (sparse) input tensors.

  Returns:
    The `(dense_tensor, sequence_length)` pair from
    `column.get_sequence_dense_tensor`.
  """
  manager = _TestStateManager()
  column.create_state(manager)
  cache = fc.FeatureTransformationCache(features)
  return column.get_sequence_dense_tensor(cache, manager)
def _get_sparse_tensors(column, features):
  """Resolves a categorical `column` against `features` (no state manager).

  Args:
    column: A sequence categorical column.
    features: A dict mapping feature keys to (sparse) input tensors.

  Returns:
    The `IdWeightPair` from `column.get_sparse_tensors`.
  """
  cache = fc.FeatureTransformationCache(features)
  return column.get_sparse_tensors(cache, None)
class SequenceCategoricalColumnWithIdentityTest(
test.TestCase, parameterized.TestCase):
......@@ -698,7 +715,7 @@ class SequenceCategoricalColumnWithIdentityTest(
expected = sparse_tensor.SparseTensorValue(**expected_args)
column = sfc.sequence_categorical_column_with_identity('aaa', num_buckets=9)
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs})
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
......@@ -737,7 +754,7 @@ class SequenceCategoricalColumnWithHashBucketTest(
column = sfc.sequence_categorical_column_with_hash_bucket(
'aaa', hash_bucket_size=10)
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs})
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
......@@ -790,7 +807,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest(
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size)
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs})
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
......@@ -814,8 +831,7 @@ class SequenceCategoricalColumnWithVocabularyFileTest(
input_placeholder_shape[1] = None
input_placeholder = array_ops.sparse_placeholder(
dtypes.string, shape=input_placeholder_shape)
id_weight_pair = column._get_sparse_tensors(
_LazyBuilder({'aaa': input_placeholder}))
id_weight_pair = _get_sparse_tensors(column, {'aaa': input_placeholder})
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
......@@ -855,7 +871,7 @@ class SequenceCategoricalColumnWithVocabularyListTest(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'))
id_weight_pair = column._get_sparse_tensors(_LazyBuilder({'aaa': inputs}))
id_weight_pair = _get_sparse_tensors(column, {'aaa': inputs})
self.assertIsNone(id_weight_pair.weight_tensor)
with monitored_session.MonitoredSession() as sess:
......@@ -922,13 +938,12 @@ class SequenceEmbeddingColumnTest(
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column = fc_old._embedding_column(
categorical_column,
dimension=embedding_dimension,
embedding_column = fc.embedding_column(
categorical_column, dimension=embedding_dimension,
initializer=_initializer)
embedding_lookup, _ = embedding_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': inputs}))
embedding_lookup, _ = _get_sequence_dense_tensor_state(
embedding_column, {'aaa': inputs})
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(
......@@ -961,10 +976,11 @@ class SequenceEmbeddingColumnTest(
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column = fc_old._embedding_column(categorical_column, dimension=2)
embedding_column = fc.embedding_column(
categorical_column, dimension=2)
_, sequence_length = embedding_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': inputs}))
_, sequence_length = _get_sequence_dense_tensor_state(
embedding_column, {'aaa': inputs})
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
......@@ -988,10 +1004,11 @@ class SequenceEmbeddingColumnTest(
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
embedding_column = fc_old._embedding_column(categorical_column, dimension=2)
embedding_column = fc.embedding_column(
categorical_column, dimension=2)
_, sequence_length = embedding_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': sparse_input}))
_, sequence_length = _get_sequence_dense_tensor_state(
embedding_column, {'aaa': sparse_input})
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(
......@@ -1058,22 +1075,18 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase):
key='aaa', num_buckets=vocabulary_size)
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
shared_embedding_columns = fc.shared_embedding_columns(
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
embedding_lookup_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
_LazyBuilder({
'aaa': sparse_input_a
}))[0]
embedding_lookup_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
_LazyBuilder({
'bbb': sparse_input_b
}))[0]
embedding_lookup_a = _get_sequence_dense_tensor(
shared_embedding_columns[0], {'aaa': sparse_input_a})[0]
embedding_lookup_b = _get_sequence_dense_tensor(
shared_embedding_columns[1], {'bbb': sparse_input_b})[0]
global_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)
self.assertItemsEqual(('embedding_weights:0',),
self.assertItemsEqual(('aaa_bbb_shared_embedding:0',),
tuple([v.name for v in global_vars]))
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(embedding_values, global_vars[0].eval(session=sess))
......@@ -1104,17 +1117,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase):
expected_sequence_length_b = [2, 1]
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
shared_embedding_columns = fc.shared_embedding_columns(
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b], dimension=2)
sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
_LazyBuilder({
'aaa': sparse_input_a
}))[1]
sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
_LazyBuilder({
'bbb': sparse_input_b
}))[1]
sequence_length_a = _get_sequence_dense_tensor(
shared_embedding_columns[0], {'aaa': sparse_input_a})[1]
sequence_length_b = _get_sequence_dense_tensor(
shared_embedding_columns[1], {'bbb': sparse_input_b})[1]
with monitored_session.MonitoredSession() as sess:
sequence_length_a = sess.run(sequence_length_a)
......@@ -1155,17 +1164,13 @@ class SequenceSharedEmbeddingColumnTest(test.TestCase):
categorical_column_b = sfc.sequence_categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
shared_embedding_columns = fc.shared_embedding_columns(
shared_embedding_columns = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b], dimension=2)
sequence_length_a = shared_embedding_columns[0]._get_sequence_dense_tensor(
_LazyBuilder({
'aaa': sparse_input_a
}))[1]
sequence_length_b = shared_embedding_columns[1]._get_sequence_dense_tensor(
_LazyBuilder({
'bbb': sparse_input_b
}))[1]
sequence_length_a = _get_sequence_dense_tensor(
shared_embedding_columns[0], {'aaa': sparse_input_a})[1]
sequence_length_b = _get_sequence_dense_tensor(
shared_embedding_columns[1], {'bbb': sparse_input_b})[1]
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(
......@@ -1221,10 +1226,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc_old._indicator_column(categorical_column)
indicator_column = fc.indicator_column(categorical_column)
indicator_tensor, _ = indicator_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': inputs}))
indicator_tensor, _ = _get_sequence_dense_tensor(
indicator_column, {'aaa': inputs})
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(expected, indicator_tensor.eval(session=sess))
......@@ -1253,10 +1258,10 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase):
categorical_column = sfc.sequence_categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc_old._indicator_column(categorical_column)
indicator_column = fc.indicator_column(categorical_column)
_, sequence_length = indicator_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': inputs}))
_, sequence_length = _get_sequence_dense_tensor(
indicator_column, {'aaa': inputs})
with monitored_session.MonitoredSession() as sess:
sequence_length = sess.run(sequence_length)
......@@ -1282,19 +1287,14 @@ class SequenceIndicatorColumnTest(test.TestCase, parameterized.TestCase):
key='aaa', num_buckets=vocabulary_size)
indicator_column = fc.indicator_column(categorical_column)
_, sequence_length = indicator_column._get_sequence_dense_tensor(
_LazyBuilder({'aaa': sparse_input}))
_, sequence_length = _get_sequence_dense_tensor(
indicator_column, {'aaa': sparse_input})
with monitored_session.MonitoredSession() as sess:
self.assertAllEqual(
expected_sequence_length, sequence_length.eval(session=sess))
# NOTE(review): this duplicates the `_get_sequence_dense_tensor` helper
# defined earlier in this file (near the other module-level test helpers);
# one of the two copies should be removed.
def _get_sequence_dense_tensor(column, features):
  """Resolves `column` against `features` without a state manager."""
  return column.get_sequence_dense_tensor(
      fc.FeatureTransformationCache(features), None)
class SequenceNumericColumnTest(test.TestCase, parameterized.TestCase):
def test_defaults(self):
......
......@@ -3111,7 +3111,7 @@ class EmbeddingColumn(
'Suggested fix: Use one of sequence_categorical_column_with_*. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
sparse_tensors = self.categorical_column.get_sequence_sparse_tensors(
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
dense_tensor = self._get_dense_tensor_internal(sparse_tensors,
state_manager)
......@@ -3307,7 +3307,7 @@ class SharedEmbeddingColumn(
'Suggested fix A: If you wish to use input_layer, use a '
'non-sequence categorical_column_with_*. '
'Suggested fix B: If you wish to create sequence input, use '
'sequence_input_layer instead of input_layer. '
'SequenceFeatureLayer instead of FeatureLayer. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
return self._get_dense_tensor_internal(transformation_cache, state_manager)
......@@ -3321,12 +3321,12 @@ class SharedEmbeddingColumn(
raise ValueError(
'In embedding_column: {}. '
'categorical_column must be of type SequenceCategoricalColumn '
'to use sequence_input_layer. '
'to use SequenceFeatureLayer. '
'Suggested fix: Use one of sequence_categorical_column_with_*. '
'Given (type {}): {}'.format(self.name, type(self.categorical_column),
self.categorical_column))
dense_tensor = self.get_dense_tensor_internal(transformation_cache,
state_manager)
dense_tensor = self._get_dense_tensor_internal(transformation_cache,
state_manager)
sparse_tensors = self.categorical_column.get_sparse_tensors(
transformation_cache, state_manager)
sequence_length = fc_old._sequence_length_from_sparse_tensor( # pylint: disable=protected-access
......@@ -4469,8 +4469,8 @@ def _verify_static_batch_size_equality(tensors, columns):
class SequenceCategoricalColumn(
FeatureColumn,
fc_old._CategoricalColumn, # pylint: disable=protected-access
CategoricalColumn,
fc_old._SequenceCategoricalColumn, # pylint: disable=protected-access
collections.namedtuple('SequenceCategoricalColumn',
('categorical_column'))):
"""Represents sequences of categorical data."""
......@@ -4533,7 +4533,7 @@ class SequenceCategoricalColumn(
weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
return CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
def get_sequence_sparse_tensors(self, transformation_cache, state_manager):
def get_sparse_tensors(self, transformation_cache, state_manager):
"""Returns an IdWeightPair.
`IdWeightPair` is a pair of `SparseTensor`s which represents ids and
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册