diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py index ceed71a49f6c0610b45d0aea91a0dc48664db5a4..2f19dafe1b5db93746627f83b4aed12b58a78c34 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column.py +++ b/tensorflow/contrib/layers/python/layers/feature_column.py @@ -329,6 +329,9 @@ class _SparseColumn(_FeatureColumn, if is_integerized and not dtype.is_integer: raise ValueError("dtype must be an integer if is_integerized is True. " "dtype: {}, column_name: {}.".format(dtype, column_name)) + if dtype != dtypes.string and not dtype.is_integer: + raise ValueError("dtype must be string or integer. " + "dtype: {}, column_name: {}".format(dtype, column_name)) if bucket_size is None and lookup_config is None: raise ValueError("one of bucket_size or lookup_config must be set. " @@ -355,9 +358,14 @@ class _SparseColumn(_FeatureColumn, raise ValueError("vocab_size must be defined. " "column_name: {}".format(column_name)) - return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized, - bucket_size, lookup_config, - combiner, dtype) + return super(_SparseColumn, cls).__new__( + cls, + column_name, + is_integerized=is_integerized, + bucket_size=bucket_size, + lookup_config=lookup_config, + combiner=combiner, + dtype=dtype) @property def name(self): @@ -440,20 +448,6 @@ class _SparseColumn(_FeatureColumn, class _SparseColumnIntegerized(_SparseColumn): """See `sparse_column_with_integerized_feature`.""" - def __new__(cls, column_name, bucket_size, combiner="sqrtn", - dtype=dtypes.int64): - if not dtype.is_integer: - raise ValueError("dtype must be an integer. " - "dtype: {}, column_name: {}".format(dtype, column_name)) - - return super(_SparseColumnIntegerized, cls).__new__( - cls, - column_name, - is_integerized=True, - bucket_size=bucket_size, - combiner=combiner, - dtype=dtype) - def insert_transformed_feature(self, columns_to_tensors): """Handles sparse column to id conversion.""" input_tensor = self._get_input_sparse_tensor(columns_to_tensors) @@ -505,29 +499,13 @@ def sparse_column_with_integerized_feature(column_name, ValueError: dtype is not integer. """ return _SparseColumnIntegerized( - column_name, bucket_size, combiner=combiner, dtype=dtype) + column_name, is_integerized=True, bucket_size=bucket_size, + combiner=combiner, dtype=dtype) class _SparseColumnHashed(_SparseColumn): """See `sparse_column_with_hash_bucket`.""" - def __new__(cls, - column_name, - hash_bucket_size, - combiner="sum", - dtype=dtypes.string): - - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError("dtype must be string or integer. " - "dtype: {}, column_name: {}".format(dtype, column_name)) - - return super(_SparseColumnHashed, cls).__new__( - cls, - column_name, - bucket_size=hash_bucket_size, - combiner=combiner, - dtype=dtype) - def insert_transformed_feature(self, columns_to_tensors): """Handles sparse column to id conversion.""" input_tensor = self._get_input_sparse_tensor(columns_to_tensors) @@ -573,26 +551,16 @@ def sparse_column_with_hash_bucket(column_name, ValueError: hash_bucket_size is not greater than 2. ValueError: dtype is neither string nor integer. """ - return _SparseColumnHashed(column_name, hash_bucket_size, combiner, dtype) + return _SparseColumnHashed( + column_name, + bucket_size=hash_bucket_size, + combiner=combiner, + dtype=dtype) class _SparseColumnKeys(_SparseColumn): """See `sparse_column_with_keys`.""" - def __new__( - cls, column_name, keys, default_value=-1, combiner="sum", - dtype=dtypes.string): - if (not dtype.is_integer) and (dtype != dtypes.string): - raise TypeError("Only integer and string are currently supported.") - - return super(_SparseColumnKeys, cls).__new__( - cls, - column_name, - combiner=combiner, - lookup_config=_SparseIdLookupConfig( - keys=keys, vocab_size=len(keys), default_value=default_value), - dtype=dtype) - def insert_transformed_feature(self, columns_to_tensors): """Handles sparse column to id conversion.""" input_tensor = self._get_input_sparse_tensor(columns_to_tensors) @@ -614,7 +582,7 @@ def sparse_column_with_keys( Args: column_name: A string defining sparse column name. - keys: A list defining vocabulary. Must be castable to `dtype`. + keys: A list or tuple defining vocabulary. Must be castable to `dtype`. default_value: The value to use for out-of-vocabulary feature values. Default is -1. combiner: A string specifying how to reduce if the sparse column is @@ -630,38 +598,18 @@ def sparse_column_with_keys( Returns: A _SparseColumnKeys with keys configuration. """ + keys = tuple(keys) return _SparseColumnKeys( - column_name, tuple(keys), default_value=default_value, combiner=combiner, + column_name, + lookup_config=_SparseIdLookupConfig( + keys=keys, vocab_size=len(keys), default_value=default_value), + combiner=combiner, dtype=dtype) class _SparseColumnVocabulary(_SparseColumn): """See `sparse_column_with_vocabulary_file`.""" - def __new__(cls, - column_name, - vocabulary_file, - num_oov_buckets=0, - vocab_size=None, - default_value=-1, - combiner="sum", - dtype=dtypes.string): - - if dtype != dtypes.string and not dtype.is_integer: - raise ValueError("dtype must be string or integer. " - "dtype: {}, column_name: {}".format(dtype, column_name)) - - return super(_SparseColumnVocabulary, cls).__new__( - cls, - column_name, - combiner=combiner, - lookup_config=_SparseIdLookupConfig( - vocabulary_file=vocabulary_file, - num_oov_buckets=num_oov_buckets, - vocab_size=vocab_size, - default_value=default_value), - dtype=dtype) - def insert_transformed_feature(self, columns_to_tensors): """Handles sparse column to id conversion.""" st = self._get_input_sparse_tensor(columns_to_tensors) @@ -726,10 +674,11 @@ def sparse_column_with_vocabulary_file(column_name, return _SparseColumnVocabulary( column_name, - vocabulary_file, - num_oov_buckets=num_oov_buckets, - vocab_size=vocab_size, - default_value=default_value, + lookup_config=_SparseIdLookupConfig( + vocabulary_file=vocabulary_file, + num_oov_buckets=num_oov_buckets, + vocab_size=vocab_size, + default_value=default_value), combiner=combiner, dtype=dtype) diff --git a/tensorflow/contrib/layers/python/layers/feature_column_test.py b/tensorflow/contrib/layers/python/layers/feature_column_test.py index a35a586b74524a4233a6b0b5d0e658d69e24d364..02bb9b70a4ca5627b3d993aa7e5f695d52901a62 100644 --- a/tensorflow/contrib/layers/python/layers/feature_column_test.py +++ b/tensorflow/contrib/layers/python/layers/feature_column_test.py @@ -554,6 +554,55 @@ class FeatureColumnTest(test.TestCase): sparse_result = sess.run(sparse_output) self.assertEquals(expected_shape, list(sparse_result.dense_shape)) + def testSparseColumnIntegerizedDeepCopy(self): + """Tests deepcopy of sparse_column_with_integerized_feature.""" + column = fc.sparse_column_with_integerized_feature("a", 10) + self.assertEqual("a", column.name) + column_copy = copy.deepcopy(column) + self.assertEqual("a", column_copy.name) + self.assertEqual(10, column_copy.bucket_size) + self.assertTrue(column_copy.is_integerized) + + def testSparseColumnHashBucketDeepCopy(self): + """Tests deepcopy of sparse_column_with_hash_bucket.""" + column = fc.sparse_column_with_hash_bucket("a", 10) + self.assertEqual("a", column.name) + column_copy = copy.deepcopy(column) + self.assertEqual("a", column_copy.name) + self.assertEqual(10, column_copy.bucket_size) + self.assertFalse(column_copy.is_integerized) + + def testSparseColumnKeysDeepCopy(self): + """Tests deepcopy of sparse_column_with_keys.""" + column = fc.sparse_column_with_keys( + "a", keys=["key0", "key1", "key2"]) + self.assertEqual("a", column.name) + column_copy = copy.deepcopy(column) + self.assertEqual("a", column_copy.name) + self.assertEqual( + fc._SparseIdLookupConfig( # pylint: disable=protected-access + keys=("key0", "key1", "key2"), + vocab_size=3, + default_value=-1), + column_copy.lookup_config) + self.assertFalse(column_copy.is_integerized) + + def testSparseColumnVocabularyDeepCopy(self): + """Tests deepcopy of sparse_column_with_vocabulary_file.""" + column = fc.sparse_column_with_vocabulary_file( + "a", vocabulary_file="path_to_file", vocab_size=3) + self.assertEqual("a", column.name) + column_copy = copy.deepcopy(column) + self.assertEqual("a", column_copy.name) + self.assertEqual( + fc._SparseIdLookupConfig( # pylint: disable=protected-access + vocabulary_file="path_to_file", + num_oov_buckets=0, + vocab_size=3, + default_value=-1), + column_copy.lookup_config) + self.assertFalse(column_copy.is_integerized) + def testCreateFeatureSpec(self): sparse_col = fc.sparse_column_with_hash_bucket( "sparse_column", hash_bucket_size=100)