提交 29293fb6 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Support deepcopy in _SparseColumn.

Change: 150488705
上级 2cc1e156
...@@ -329,6 +329,9 @@ class _SparseColumn(_FeatureColumn, ...@@ -329,6 +329,9 @@ class _SparseColumn(_FeatureColumn,
if is_integerized and not dtype.is_integer: if is_integerized and not dtype.is_integer:
raise ValueError("dtype must be an integer if is_integerized is True. " raise ValueError("dtype must be an integer if is_integerized is True. "
"dtype: {}, column_name: {}.".format(dtype, column_name)) "dtype: {}, column_name: {}.".format(dtype, column_name))
if dtype != dtypes.string and not dtype.is_integer:
raise ValueError("dtype must be string or integer. "
"dtype: {}, column_name: {}".format(dtype, column_name))
if bucket_size is None and lookup_config is None: if bucket_size is None and lookup_config is None:
raise ValueError("one of bucket_size or lookup_config must be set. " raise ValueError("one of bucket_size or lookup_config must be set. "
...@@ -355,9 +358,14 @@ class _SparseColumn(_FeatureColumn, ...@@ -355,9 +358,14 @@ class _SparseColumn(_FeatureColumn,
raise ValueError("vocab_size must be defined. " raise ValueError("vocab_size must be defined. "
"column_name: {}".format(column_name)) "column_name: {}".format(column_name))
return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized, return super(_SparseColumn, cls).__new__(
bucket_size, lookup_config, cls,
combiner, dtype) column_name,
is_integerized=is_integerized,
bucket_size=bucket_size,
lookup_config=lookup_config,
combiner=combiner,
dtype=dtype)
@property @property
def name(self): def name(self):
...@@ -440,20 +448,6 @@ class _SparseColumn(_FeatureColumn, ...@@ -440,20 +448,6 @@ class _SparseColumn(_FeatureColumn,
class _SparseColumnIntegerized(_SparseColumn): class _SparseColumnIntegerized(_SparseColumn):
"""See `sparse_column_with_integerized_feature`.""" """See `sparse_column_with_integerized_feature`."""
def __new__(cls, column_name, bucket_size, combiner="sqrtn",
            dtype=dtypes.int64):
  """Builds an integerized sparse column after validating the dtype.

  Raises:
    ValueError: if `dtype` is not an integer type.
  """
  # An integerized column treats raw feature values as ids directly,
  # so anything other than an integer dtype is rejected up front.
  if dtype.is_integer:
    return super(_SparseColumnIntegerized, cls).__new__(
        cls,
        column_name,
        is_integerized=True,
        bucket_size=bucket_size,
        combiner=combiner,
        dtype=dtype)
  raise ValueError("dtype must be an integer. "
                   "dtype: {}, column_name: {}".format(dtype, column_name))
def insert_transformed_feature(self, columns_to_tensors): def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion.""" """Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors) input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
...@@ -505,29 +499,13 @@ def sparse_column_with_integerized_feature(column_name, ...@@ -505,29 +499,13 @@ def sparse_column_with_integerized_feature(column_name,
ValueError: dtype is not integer. ValueError: dtype is not integer.
""" """
return _SparseColumnIntegerized( return _SparseColumnIntegerized(
column_name, bucket_size, combiner=combiner, dtype=dtype) column_name, is_integerized=True, bucket_size=bucket_size,
combiner=combiner, dtype=dtype)
class _SparseColumnHashed(_SparseColumn): class _SparseColumnHashed(_SparseColumn):
"""See `sparse_column_with_hash_bucket`.""" """See `sparse_column_with_hash_bucket`."""
def __new__(cls,
            column_name,
            hash_bucket_size,
            combiner="sum",
            dtype=dtypes.string):
  """Builds a hashed sparse column after validating the dtype.

  Raises:
    ValueError: if `dtype` is neither string nor an integer type.
  """
  # Hashing is only defined for string or integer feature values.
  dtype_ok = (dtype == dtypes.string) or dtype.is_integer
  if not dtype_ok:
    raise ValueError("dtype must be string or integer. "
                     "dtype: {}, column_name: {}".format(dtype, column_name))
  # Note: the caller-facing name is hash_bucket_size; the parent column
  # stores it under the generic bucket_size field.
  return super(_SparseColumnHashed, cls).__new__(
      cls,
      column_name,
      bucket_size=hash_bucket_size,
      combiner=combiner,
      dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors): def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion.""" """Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors) input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
...@@ -573,26 +551,16 @@ def sparse_column_with_hash_bucket(column_name, ...@@ -573,26 +551,16 @@ def sparse_column_with_hash_bucket(column_name,
ValueError: hash_bucket_size is not greater than 2. ValueError: hash_bucket_size is not greater than 2.
ValueError: dtype is neither string nor integer. ValueError: dtype is neither string nor integer.
""" """
return _SparseColumnHashed(column_name, hash_bucket_size, combiner, dtype) return _SparseColumnHashed(
column_name,
bucket_size=hash_bucket_size,
combiner=combiner,
dtype=dtype)
class _SparseColumnKeys(_SparseColumn): class _SparseColumnKeys(_SparseColumn):
"""See `sparse_column_with_keys`.""" """See `sparse_column_with_keys`."""
def __new__(
    cls, column_name, keys, default_value=-1, combiner="sum",
    dtype=dtypes.string):
  """Builds a keys-based sparse column with an in-memory vocabulary.

  Raises:
    TypeError: if `dtype` is neither string nor an integer type.
  """
  supported = dtype.is_integer or (dtype == dtypes.string)
  if not supported:
    raise TypeError("Only integer and string are currently supported.")
  # The explicit key list becomes the lookup vocabulary; ids outside it
  # map to default_value.
  lookup = _SparseIdLookupConfig(
      keys=keys, vocab_size=len(keys), default_value=default_value)
  return super(_SparseColumnKeys, cls).__new__(
      cls,
      column_name,
      combiner=combiner,
      lookup_config=lookup,
      dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors): def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion.""" """Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors) input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
...@@ -614,7 +582,7 @@ def sparse_column_with_keys( ...@@ -614,7 +582,7 @@ def sparse_column_with_keys(
Args: Args:
column_name: A string defining sparse column name. column_name: A string defining sparse column name.
keys: A list defining vocabulary. Must be castable to `dtype`. keys: A list or tuple defining vocabulary. Must be castable to `dtype`.
default_value: The value to use for out-of-vocabulary feature values. default_value: The value to use for out-of-vocabulary feature values.
Default is -1. Default is -1.
combiner: A string specifying how to reduce if the sparse column is combiner: A string specifying how to reduce if the sparse column is
...@@ -630,38 +598,18 @@ def sparse_column_with_keys( ...@@ -630,38 +598,18 @@ def sparse_column_with_keys(
Returns: Returns:
A _SparseColumnKeys with keys configuration. A _SparseColumnKeys with keys configuration.
""" """
keys = tuple(keys)
return _SparseColumnKeys( return _SparseColumnKeys(
column_name, tuple(keys), default_value=default_value, combiner=combiner, column_name,
lookup_config=_SparseIdLookupConfig(
keys=keys, vocab_size=len(keys), default_value=default_value),
combiner=combiner,
dtype=dtype) dtype=dtype)
class _SparseColumnVocabulary(_SparseColumn): class _SparseColumnVocabulary(_SparseColumn):
"""See `sparse_column_with_vocabulary_file`.""" """See `sparse_column_with_vocabulary_file`."""
def __new__(cls,
            column_name,
            vocabulary_file,
            num_oov_buckets=0,
            vocab_size=None,
            default_value=-1,
            combiner="sum",
            dtype=dtypes.string):
  """Builds a vocabulary-file-backed sparse column after dtype validation.

  Raises:
    ValueError: if `dtype` is neither string nor an integer type.
  """
  # Vocabulary lookups are only defined for string or integer values.
  dtype_ok = (dtype == dtypes.string) or dtype.is_integer
  if not dtype_ok:
    raise ValueError("dtype must be string or integer. "
                     "dtype: {}, column_name: {}".format(dtype, column_name))
  # All file-lookup parameters are bundled into a single config record.
  lookup = _SparseIdLookupConfig(
      vocabulary_file=vocabulary_file,
      num_oov_buckets=num_oov_buckets,
      vocab_size=vocab_size,
      default_value=default_value)
  return super(_SparseColumnVocabulary, cls).__new__(
      cls,
      column_name,
      combiner=combiner,
      lookup_config=lookup,
      dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors): def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion.""" """Handles sparse column to id conversion."""
st = self._get_input_sparse_tensor(columns_to_tensors) st = self._get_input_sparse_tensor(columns_to_tensors)
...@@ -726,10 +674,11 @@ def sparse_column_with_vocabulary_file(column_name, ...@@ -726,10 +674,11 @@ def sparse_column_with_vocabulary_file(column_name,
return _SparseColumnVocabulary( return _SparseColumnVocabulary(
column_name, column_name,
vocabulary_file, lookup_config=_SparseIdLookupConfig(
num_oov_buckets=num_oov_buckets, vocabulary_file=vocabulary_file,
vocab_size=vocab_size, num_oov_buckets=num_oov_buckets,
default_value=default_value, vocab_size=vocab_size,
default_value=default_value),
combiner=combiner, combiner=combiner,
dtype=dtype) dtype=dtype)
......
...@@ -554,6 +554,55 @@ class FeatureColumnTest(test.TestCase): ...@@ -554,6 +554,55 @@ class FeatureColumnTest(test.TestCase):
sparse_result = sess.run(sparse_output) sparse_result = sess.run(sparse_output)
self.assertEquals(expected_shape, list(sparse_result.dense_shape)) self.assertEquals(expected_shape, list(sparse_result.dense_shape))
def testSparseColumnIntegerizedDeepCopy(self):
  """Verifies that an integerized sparse column survives copy.deepcopy."""
  original = fc.sparse_column_with_integerized_feature("a", 10)
  self.assertEqual("a", original.name)
  clone = copy.deepcopy(original)
  # The clone must preserve name, bucket size, and the integerized flag.
  self.assertEqual("a", clone.name)
  self.assertEqual(10, clone.bucket_size)
  self.assertTrue(clone.is_integerized)
def testSparseColumnHashBucketDeepCopy(self):
  """Verifies that a hash-bucket sparse column survives copy.deepcopy."""
  original = fc.sparse_column_with_hash_bucket("a", 10)
  self.assertEqual("a", original.name)
  clone = copy.deepcopy(original)
  # The clone must preserve name and bucket size; hashed columns are
  # never integerized.
  self.assertEqual("a", clone.name)
  self.assertEqual(10, clone.bucket_size)
  self.assertFalse(clone.is_integerized)
def testSparseColumnKeysDeepCopy(self):
  """Verifies that a keys-based sparse column survives copy.deepcopy."""
  original = fc.sparse_column_with_keys(
      "a", keys=["key0", "key1", "key2"])
  self.assertEqual("a", original.name)
  clone = copy.deepcopy(original)
  self.assertEqual("a", clone.name)
  # The full lookup configuration must round-trip through deepcopy.
  expected_config = fc._SparseIdLookupConfig(  # pylint: disable=protected-access
      keys=("key0", "key1", "key2"),
      vocab_size=3,
      default_value=-1)
  self.assertEqual(expected_config, clone.lookup_config)
  self.assertFalse(clone.is_integerized)
def testSparseColumnVocabularyDeepCopy(self):
  """Verifies that a vocabulary-file sparse column survives copy.deepcopy."""
  original = fc.sparse_column_with_vocabulary_file(
      "a", vocabulary_file="path_to_file", vocab_size=3)
  self.assertEqual("a", original.name)
  clone = copy.deepcopy(original)
  self.assertEqual("a", clone.name)
  # The full lookup configuration must round-trip through deepcopy.
  expected_config = fc._SparseIdLookupConfig(  # pylint: disable=protected-access
      vocabulary_file="path_to_file",
      num_oov_buckets=0,
      vocab_size=3,
      default_value=-1)
  self.assertEqual(expected_config, clone.lookup_config)
  self.assertFalse(clone.is_integerized)
def testCreateFeatureSpec(self): def testCreateFeatureSpec(self):
sparse_col = fc.sparse_column_with_hash_bucket( sparse_col = fc.sparse_column_with_hash_bucket(
"sparse_column", hash_bucket_size=100) "sparse_column", hash_bucket_size=100)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册