提交 29293fb6 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Support deepcopy in _SparseColumn.

Change: 150488705
上级 2cc1e156
......@@ -329,6 +329,9 @@ class _SparseColumn(_FeatureColumn,
if is_integerized and not dtype.is_integer:
raise ValueError("dtype must be an integer if is_integerized is True. "
"dtype: {}, column_name: {}.".format(dtype, column_name))
if dtype != dtypes.string and not dtype.is_integer:
raise ValueError("dtype must be string or integer. "
"dtype: {}, column_name: {}".format(dtype, column_name))
if bucket_size is None and lookup_config is None:
raise ValueError("one of bucket_size or lookup_config must be set. "
......@@ -355,9 +358,14 @@ class _SparseColumn(_FeatureColumn,
raise ValueError("vocab_size must be defined. "
"column_name: {}".format(column_name))
return super(_SparseColumn, cls).__new__(cls, column_name, is_integerized,
bucket_size, lookup_config,
combiner, dtype)
return super(_SparseColumn, cls).__new__(
cls,
column_name,
is_integerized=is_integerized,
bucket_size=bucket_size,
lookup_config=lookup_config,
combiner=combiner,
dtype=dtype)
@property
def name(self):
......@@ -440,20 +448,6 @@ class _SparseColumn(_FeatureColumn,
class _SparseColumnIntegerized(_SparseColumn):
"""See `sparse_column_with_integerized_feature`."""
def __new__(cls, column_name, bucket_size, combiner="sqrtn",
dtype=dtypes.int64):
if not dtype.is_integer:
raise ValueError("dtype must be an integer. "
"dtype: {}, column_name: {}".format(dtype, column_name))
return super(_SparseColumnIntegerized, cls).__new__(
cls,
column_name,
is_integerized=True,
bucket_size=bucket_size,
combiner=combiner,
dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
......@@ -505,29 +499,13 @@ def sparse_column_with_integerized_feature(column_name,
ValueError: dtype is not integer.
"""
return _SparseColumnIntegerized(
column_name, bucket_size, combiner=combiner, dtype=dtype)
column_name, is_integerized=True, bucket_size=bucket_size,
combiner=combiner, dtype=dtype)
class _SparseColumnHashed(_SparseColumn):
"""See `sparse_column_with_hash_bucket`."""
def __new__(cls,
column_name,
hash_bucket_size,
combiner="sum",
dtype=dtypes.string):
if dtype != dtypes.string and not dtype.is_integer:
raise ValueError("dtype must be string or integer. "
"dtype: {}, column_name: {}".format(dtype, column_name))
return super(_SparseColumnHashed, cls).__new__(
cls,
column_name,
bucket_size=hash_bucket_size,
combiner=combiner,
dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
......@@ -573,26 +551,16 @@ def sparse_column_with_hash_bucket(column_name,
ValueError: hash_bucket_size is not greater than 2.
ValueError: dtype is neither string nor integer.
"""
return _SparseColumnHashed(column_name, hash_bucket_size, combiner, dtype)
return _SparseColumnHashed(
column_name,
bucket_size=hash_bucket_size,
combiner=combiner,
dtype=dtype)
class _SparseColumnKeys(_SparseColumn):
"""See `sparse_column_with_keys`."""
def __new__(
cls, column_name, keys, default_value=-1, combiner="sum",
dtype=dtypes.string):
if (not dtype.is_integer) and (dtype != dtypes.string):
raise TypeError("Only integer and string are currently supported.")
return super(_SparseColumnKeys, cls).__new__(
cls,
column_name,
combiner=combiner,
lookup_config=_SparseIdLookupConfig(
keys=keys, vocab_size=len(keys), default_value=default_value),
dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
......@@ -614,7 +582,7 @@ def sparse_column_with_keys(
Args:
column_name: A string defining sparse column name.
keys: A list defining vocabulary. Must be castable to `dtype`.
keys: A list or tuple defining vocabulary. Must be castable to `dtype`.
default_value: The value to use for out-of-vocabulary feature values.
Default is -1.
combiner: A string specifying how to reduce if the sparse column is
......@@ -630,38 +598,18 @@ def sparse_column_with_keys(
Returns:
A _SparseColumnKeys with keys configuration.
"""
keys = tuple(keys)
return _SparseColumnKeys(
column_name, tuple(keys), default_value=default_value, combiner=combiner,
column_name,
lookup_config=_SparseIdLookupConfig(
keys=keys, vocab_size=len(keys), default_value=default_value),
combiner=combiner,
dtype=dtype)
class _SparseColumnVocabulary(_SparseColumn):
"""See `sparse_column_with_vocabulary_file`."""
def __new__(cls,
column_name,
vocabulary_file,
num_oov_buckets=0,
vocab_size=None,
default_value=-1,
combiner="sum",
dtype=dtypes.string):
if dtype != dtypes.string and not dtype.is_integer:
raise ValueError("dtype must be string or integer. "
"dtype: {}, column_name: {}".format(dtype, column_name))
return super(_SparseColumnVocabulary, cls).__new__(
cls,
column_name,
combiner=combiner,
lookup_config=_SparseIdLookupConfig(
vocabulary_file=vocabulary_file,
num_oov_buckets=num_oov_buckets,
vocab_size=vocab_size,
default_value=default_value),
dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
st = self._get_input_sparse_tensor(columns_to_tensors)
......@@ -726,10 +674,11 @@ def sparse_column_with_vocabulary_file(column_name,
return _SparseColumnVocabulary(
column_name,
vocabulary_file,
num_oov_buckets=num_oov_buckets,
vocab_size=vocab_size,
default_value=default_value,
lookup_config=_SparseIdLookupConfig(
vocabulary_file=vocabulary_file,
num_oov_buckets=num_oov_buckets,
vocab_size=vocab_size,
default_value=default_value),
combiner=combiner,
dtype=dtype)
......
......@@ -554,6 +554,55 @@ class FeatureColumnTest(test.TestCase):
sparse_result = sess.run(sparse_output)
self.assertEquals(expected_shape, list(sparse_result.dense_shape))
def testSparseColumnIntegerizedDeepCopy(self):
"""Tests deepcopy of sparse_column_with_integerized_feature."""
column = fc.sparse_column_with_integerized_feature("a", 10)
self.assertEqual("a", column.name)
column_copy = copy.deepcopy(column)
self.assertEqual("a", column_copy.name)
self.assertEqual(10, column_copy.bucket_size)
self.assertTrue(column_copy.is_integerized)
def testSparseColumnHashBucketDeepCopy(self):
"""Tests deepcopy of sparse_column_with_hash_bucket."""
column = fc.sparse_column_with_hash_bucket("a", 10)
self.assertEqual("a", column.name)
column_copy = copy.deepcopy(column)
self.assertEqual("a", column_copy.name)
self.assertEqual(10, column_copy.bucket_size)
self.assertFalse(column_copy.is_integerized)
def testSparseColumnKeysDeepCopy(self):
"""Tests deepcopy of sparse_column_with_keys."""
column = fc.sparse_column_with_keys(
"a", keys=["key0", "key1", "key2"])
self.assertEqual("a", column.name)
column_copy = copy.deepcopy(column)
self.assertEqual("a", column_copy.name)
self.assertEqual(
fc._SparseIdLookupConfig( # pylint: disable=protected-access
keys=("key0", "key1", "key2"),
vocab_size=3,
default_value=-1),
column_copy.lookup_config)
self.assertFalse(column_copy.is_integerized)
def testSparseColumnVocabularyDeepCopy(self):
"""Tests deepcopy of sparse_column_with_vocabulary_file."""
column = fc.sparse_column_with_vocabulary_file(
"a", vocabulary_file="path_to_file", vocab_size=3)
self.assertEqual("a", column.name)
column_copy = copy.deepcopy(column)
self.assertEqual("a", column_copy.name)
self.assertEqual(
fc._SparseIdLookupConfig( # pylint: disable=protected-access
vocabulary_file="path_to_file",
num_oov_buckets=0,
vocab_size=3,
default_value=-1),
column_copy.lookup_config)
self.assertFalse(column_copy.is_integerized)
def testCreateFeatureSpec(self):
sparse_col = fc.sparse_column_with_hash_bucket(
"sparse_column", hash_bucket_size=100)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册