Commit de5f2aa8 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Support integer sparse feature values.

Add some additional values args to name_scope calls.
Change: 149765769
Parent 3af39a00
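The headline change: `sparse_column_with_keys` gains a `dtype` argument so sparse feature columns can be keyed by integers, not only strings, and the contrib lookup tables grow matching `key_dtype`/`dtype` parameters. A minimal sketch of the resulting API, assuming the TF 1.x `tf.contrib` import paths of this era:

```python
# Minimal sketch of the new API surface (assumes TF 1.x contrib paths).
import tensorflow as tf

fc = tf.contrib.layers

# String keys keep the old default behavior.
str_col = fc.sparse_column_with_keys("color", keys=["red", "green", "blue"])

# Integer vocabularies are now supported via the new dtype argument.
int_col = fc.sparse_column_with_keys(
    "code", keys=[42, 1, -1000], dtype=tf.int64)
```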
@@ -579,28 +579,34 @@ def sparse_column_with_hash_bucket(column_name,
class _SparseColumnKeys(_SparseColumn):
"""See `sparse_column_with_keys`."""
def __new__(cls, column_name, keys, default_value=-1, combiner="sum"):
def __new__(
cls, column_name, keys, default_value=-1, combiner="sum",
dtype=dtypes.string):
if (not dtype.is_integer) and (dtype != dtypes.string):
raise TypeError("Only integer and string are currently supported.")
return super(_SparseColumnKeys, cls).__new__(
cls,
column_name,
combiner=combiner,
lookup_config=_SparseIdLookupConfig(
keys=keys, vocab_size=len(keys), default_value=default_value),
dtype=dtypes.string)
dtype=dtype)
def insert_transformed_feature(self, columns_to_tensors):
"""Handles sparse column to id conversion."""
input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
table = lookup.string_to_index_table_from_tensor(
mapping=list(self.lookup_config.keys),
table = lookup.index_table_from_tensor(
mapping=tuple(self.lookup_config.keys),
default_value=self.lookup_config.default_value,
dtype=self.dtype,
name="lookup")
columns_to_tensors[self] = table.lookup(input_tensor)
def sparse_column_with_keys(
column_name, keys, default_value=-1, combiner="sum"):
column_name, keys, default_value=-1, combiner="sum", dtype=dtypes.string):
"""Creates a _SparseColumn with keys.
Look up logic is as follows:
@@ -608,7 +614,7 @@ def sparse_column_with_keys(
Args:
column_name: A string defining sparse column name.
keys: a string list defining vocabulary.
keys: A list defining vocabulary. Must be castable to `dtype`.
default_value: The value to use for out-of-vocabulary feature values.
Default is -1.
combiner: A string specifying how to reduce if the sparse column is
@@ -619,12 +625,14 @@ def sparse_column_with_keys(
* "mean": do l1 normalization on features in the column
* "sqrtn": do l2 normalization on features in the column
For more information: `tf.embedding_lookup_sparse`.
dtype: Type of features. Only integer and string are supported.
Returns:
A _SparseColumnKeys with keys configuration.
"""
return _SparseColumnKeys(
column_name, tuple(keys), default_value=default_value, combiner=combiner)
column_name, tuple(keys), default_value=default_value, combiner=combiner,
dtype=dtype)
class _SparseColumnVocabulary(_SparseColumn):
@@ -154,7 +154,7 @@ class FeatureColumnTest(test.TestCase):
# a3 is a completely different sparse column with a1 and a2, but since the
# same shared_embedding_name is passed in, a3 will have the same embedding
# as a1 and a2
a3 = fc.sparse_column_with_keys("a3", ["cathy", "tom", "anderson"])
a3 = fc.sparse_column_with_keys("a3", [42, 1, -1000], dtype=dtypes.int32)
e = fc.shared_embedding_columns(
[a3],
dimension=4,
@@ -446,7 +446,15 @@ class FeatureColumnTest(test.TestCase):
fc.crossed_column(
set([b, fc.real_valued_column("real")]), hash_bucket_size=10000)
def testWeightedSparseColumnDtypes(self):
def testFloat32WeightedSparseInt32ColumnDtypes(self):
ids = fc.sparse_column_with_keys("ids", [42, 1, -1000], dtype=dtypes.int32)
weighted_ids = fc.weighted_sparse_column(ids, "weights")
self.assertDictEqual({
"ids": parsing_ops.VarLenFeature(dtypes.int32),
"weights": parsing_ops.VarLenFeature(dtypes.float32)
}, weighted_ids.config)
def testFloat32WeightedSparseStringColumnDtypes(self):
ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"])
weighted_ids = fc.weighted_sparse_column(ids, "weights")
self.assertDictEqual({
@@ -454,6 +462,8 @@ class FeatureColumnTest(test.TestCase):
"weights": parsing_ops.VarLenFeature(dtypes.float32)
}, weighted_ids.config)
def testInt32WeightedSparseStringColumnDtypes(self):
ids = fc.sparse_column_with_keys("ids", ["marlo", "omar", "stringer"])
weighted_ids = fc.weighted_sparse_column(ids, "weights", dtype=dtypes.int32)
self.assertDictEqual({
"ids": parsing_ops.VarLenFeature(dtypes.string),
@@ -465,6 +475,19 @@ class FeatureColumnTest(test.TestCase):
weighted_ids = fc.weighted_sparse_column(
ids, "weights", dtype=dtypes.string)
def testInt32WeightedSparseInt64ColumnDtypes(self):
ids = fc.sparse_column_with_keys("ids", [42, 1, -1000], dtype=dtypes.int64)
weighted_ids = fc.weighted_sparse_column(ids, "weights", dtype=dtypes.int32)
self.assertDictEqual({
"ids": parsing_ops.VarLenFeature(dtypes.int64),
"weights": parsing_ops.VarLenFeature(dtypes.int32)
}, weighted_ids.config)
with self.assertRaisesRegexp(ValueError,
"dtype is not convertible to float"):
weighted_ids = fc.weighted_sparse_column(
ids, "weights", dtype=dtypes.string)
def testRealValuedColumnDtypes(self):
rvc = fc.real_valued_column("rvc")
self.assertDictEqual(
@@ -547,10 +570,14 @@ class FeatureColumnTest(test.TestCase):
fc.sparse_column_with_hash_bucket(
"sparse_column_for_embedding", hash_bucket_size=10),
dimension=4)
sparse_id_col = fc.sparse_column_with_keys("id_column",
["marlo", "omar", "stringer"])
weighted_id_col = fc.weighted_sparse_column(sparse_id_col,
"id_weights_column")
str_sparse_id_col = fc.sparse_column_with_keys(
"str_id_column", ["marlo", "omar", "stringer"])
int32_sparse_id_col = fc.sparse_column_with_keys(
"int32_id_column", [42, 1, -1000], dtype=dtypes.int32)
int64_sparse_id_col = fc.sparse_column_with_keys(
"int64_id_column", [42, 1, -1000], dtype=dtypes.int64)
weighted_id_col = fc.weighted_sparse_column(str_sparse_id_col,
"str_id_weights_column")
real_valued_col1 = fc.real_valued_column("real_valued_column1")
real_valued_col2 = fc.real_valued_column("real_valued_column2", 5)
real_valued_col3 = fc.real_valued_column(
@@ -564,18 +591,22 @@ class FeatureColumnTest(test.TestCase):
b = fc.sparse_column_with_hash_bucket("cross_bbb", hash_bucket_size=100)
cross_col = fc.crossed_column(set([a, b]), hash_bucket_size=10000)
feature_columns = set([
sparse_col, embedding_col, weighted_id_col, real_valued_col1,
real_valued_col2, real_valued_col3, bucketized_col1, bucketized_col2,
cross_col
sparse_col, embedding_col, weighted_id_col, int32_sparse_id_col,
int64_sparse_id_col, real_valued_col1, real_valued_col2,
real_valued_col3, bucketized_col1, bucketized_col2, cross_col
])
expected_config = {
"sparse_column":
parsing_ops.VarLenFeature(dtypes.string),
"sparse_column_for_embedding":
parsing_ops.VarLenFeature(dtypes.string),
"id_column":
"str_id_column":
parsing_ops.VarLenFeature(dtypes.string),
"id_weights_column":
"int32_id_column":
parsing_ops.VarLenFeature(dtypes.int32),
"int64_id_column":
parsing_ops.VarLenFeature(dtypes.int64),
"str_id_weights_column":
parsing_ops.VarLenFeature(dtypes.float32),
"real_valued_column1":
parsing_ops.FixedLenFeature(
@@ -784,11 +815,13 @@ class FeatureColumnTest(test.TestCase):
def testInitCrossedColumnWeightsFromCkpt(self):
sparse_col_1 = fc.sparse_column_with_hash_bucket(
column_name="col_1", hash_bucket_size=4)
sparse_col_2 = fc.sparse_column_with_hash_bucket(
column_name="col_2", hash_bucket_size=4)
sparse_col_2 = fc.sparse_column_with_keys(
column_name="col_2", keys=("foo", "bar", "baz"))
sparse_col_3 = fc.sparse_column_with_keys(
column_name="col_3", keys=(42, 1, -1000), dtype=dtypes.int64)
crossed_col = fc.crossed_column(
columns=[sparse_col_1, sparse_col_2], hash_bucket_size=4)
columns=[sparse_col_1, sparse_col_2, sparse_col_3], hash_bucket_size=4)
input_tensor = sparse_tensor_lib.SparseTensor(
indices=[[0, 0], [1, 1], [2, 2], [3, 3]],
@@ -804,7 +837,8 @@ class FeatureColumnTest(test.TestCase):
_, col_weights, _ = (
feature_column_ops.weighted_sum_from_feature_columns({
sparse_col_1.name: input_tensor,
sparse_col_2.name: input_tensor
sparse_col_2.name: input_tensor,
sparse_col_3.name: input_tensor
}, [crossed_col], 1))
# Update the weights since default initializer initializes all weights
# to 0.0.
@@ -827,9 +861,9 @@ class FeatureColumnTest(test.TestCase):
columns=[sparse_col_1, sparse_col_2],
hash_bucket_size=4,
ckpt_to_load_from=checkpoint_path,
tensor_name_in_ckpt=("run_1/col_1_X_col_2/"
tensor_name_in_ckpt=("run_1/col_1_X_col_2_X_col_3/"
"weighted_sum_from_feature_columns/"
"col_1_X_col_2/weights"))
"col_1_X_col_2_X_col_3/weights"))
with variable_scope.variable_scope("run_2"):
# This will initialize the crossed column weights from provided checkpoint
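The weighted-column tests above boil down to a simple rule; a hedged sketch (TF 1.x contrib paths assumed): ids may be integer or string, but weights must remain convertible to float.

```python
# Hedged sketch of the dtype rules exercised by the tests above
# (assumes TF 1.x contrib paths).
import tensorflow as tf

fc = tf.contrib.layers

ids = fc.sparse_column_with_keys("ids", [42, 1, -1000], dtype=tf.int64)
ok = fc.weighted_sparse_column(ids, "weights", dtype=tf.int32)  # ints cast to float
# fc.weighted_sparse_column(ids, "weights", dtype=tf.string)
#   -> ValueError: dtype is not convertible to float
```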
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# TODO(ptucker): deprecate string_to_index_table_from_file and
# string_to_index_table_from_tensor 2017-04-10.
"""Ops for lookup operations.
@@string_to_index
@@string_to_index_table_from_file
@@string_to_index_table_from_tensor
@@index_table_from_file
@@index_table_from_tensor
@@index_to_string
@@index_to_string_table_from_file
@@index_to_string_table_from_tensor
@@ -178,8 +178,9 @@ class InitializableLookupTableBase(LookupInterface):
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
with ops.name_scope(name, "%s_Lookup" % self._name,
[self._table_ref]) as scope:
with ops.name_scope(
name, "%s_Lookup" % self._name,
(self._table_ref, key_tensor, self._default_value)) as scope:
# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(
self._table_ref, key_tensor, self._default_value, name=scope)
@@ -215,7 +216,8 @@ class HashTable(InitializableLookupTableBase):
the table will be immutable.
Args:
initializer: The table initializer to use.
initializer: The table initializer to use. See `HashTable` kernel for
supported key and value types.
default_value: The value to use if a key is missing in the table.
shared_name: If non-empty, this table will be shared under
the given name across multiple sessions.
@@ -224,7 +226,8 @@ class HashTable(InitializableLookupTableBase):
Returns:
A `HashTable` object.
"""
with ops.name_scope(name, "hash_table", [initializer]) as scope:
with ops.name_scope(
name, "hash_table", (initializer, default_value)) as scope:
# pylint: disable=protected-access
table_ref = gen_data_flow_ops._hash_table(
shared_name=shared_name,
@@ -301,7 +304,9 @@ class KeyValueTensorInitializer(TableInitializerBase):
key and value data types.
"""
table.check_table_dtypes(self._keys.dtype, self._values.dtype)
with ops.name_scope(self._name, values=[table]) as scope:
with ops.name_scope(
self._name,
values=(table.table_ref, self._keys, self._values)) as scope:
# pylint: disable=protected-access
init_op = gen_data_flow_ops._initialize_table(table.table_ref,
self._keys,
@@ -422,9 +427,11 @@ class TextFileInitializer(TableInitializerBase):
if key_index == TextFileIndex.LINE_NUMBER and key_dtype != dtypes.int64:
raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
(dtypes.int64, key_dtype))
if key_index == TextFileIndex.WHOLE_LINE and key_dtype != dtypes.string:
raise ValueError("Signature mismatch. Keys must be dtype %s, got %s." %
(dtypes.string, key_dtype))
if ((key_index == TextFileIndex.WHOLE_LINE) and
(not key_dtype.is_integer) and (key_dtype != dtypes.string)):
raise ValueError(
"Signature mismatch. Keys must be integer or string, got %s." %
key_dtype)
if value_index < -2:
raise ValueError("Invalid value index %s." % (value_index))
@@ -461,7 +468,8 @@ class TextFileInitializer(TableInitializerBase):
key and value data types.
"""
table.check_table_dtypes(self.key_dtype, self.value_dtype)
with ops.name_scope(self._name, "text_file_init", [table]) as scope:
with ops.name_scope(
self._name, "text_file_init", (table.table_ref,)) as scope:
filename = ops.convert_to_tensor(self._filename,
dtypes.string,
name="asset_filepath")
@@ -539,7 +547,8 @@ class TextFileIdTableInitializer(TextFileInitializer):
value_column_index=TextFileIndex.LINE_NUMBER,
vocab_size=None,
delimiter="\t",
name="text_file_id_table_init"):
name="text_file_id_table_init",
key_dtype=dtypes.string):
"""Constructs an initializer for an string-to-id table from a text file.
It populates a table that its key and value types are string and int64,
@@ -565,13 +574,14 @@ class TextFileIdTableInitializer(TextFileInitializer):
vocab_size: The number of elements in the file, if known.
delimiter: The delimiter to separate fields in a line.
name: Optional name for the op.
key_dtype: The `key` data type.
Raises:
TypeError: when the filename is empty, or when the table key and value
data types do not match the expected data types.
"""
super(TextFileIdTableInitializer, self).__init__(filename,
dtypes.string,
key_dtype,
key_column_index,
dtypes.int64,
value_column_index,
@@ -621,6 +631,12 @@ class StrongHashSpec(HasherSpec):
return super(cls, StrongHashSpec).__new__(cls, "stronghash", key)
def _as_string(tensor):
if dtypes.string == tensor.dtype.base_dtype:
return tensor
return string_ops.as_string(tensor)
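`_as_string` exists because the out-of-vocabulary path hashes keys with string-based hash-bucket ops, so integer keys are stringified before hashing. A hedged sketch using the equivalent public ops:

```python
# Hedged sketch: integer keys are converted to strings before being
# hashed into OOV buckets, since the fast hash op works on strings.
import tensorflow as tf

int_keys = tf.constant([42, 7], dtype=tf.int64)
str_keys = tf.as_string(int_keys)  # ["42", "7"]
buckets = tf.string_to_hash_bucket_fast(str_keys, num_buckets=10)
```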
class IdTableWithHashBuckets(LookupInterface):
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
@@ -663,15 +679,19 @@ class IdTableWithHashBuckets(LookupInterface):
table,
num_oov_buckets,
hasher_spec=FastHashSpec,
name=None):
name=None,
key_dtype=None):
"""Construct a `IdTableWithHashBuckets` object.
Args:
table: Table that maps string to ids.
table: Table that maps `tf.string` or `tf.int64` keys to `tf.int64` ids.
num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignment of out-of-vocabulary buckets (optional).
name: A name for the operation (optional).
key_dtype: Data type of keys passed to `lookup`. Defaults to
`table.key_dtype` if `table` is specified, otherwise `tf.string`.
Must be string or integer, and must be castable to `table.key_dtype`.
Raises:
ValueError: when `table` is None and `num_oov_buckets` is not positive.
@@ -682,22 +702,37 @@ class IdTableWithHashBuckets(LookupInterface):
if name:
name = name.rstrip("/")
if table:
table.check_table_dtypes(dtypes.string, dtypes.int64)
if key_dtype is None:
key_dtype = table.key_dtype
supported_table_key_dtypes = (dtypes.int64, dtypes.string)
if table.key_dtype not in supported_table_key_dtypes:
raise TypeError("Invalid key dtype, expected one of %s, but got %s." %
(supported_table_key_dtypes, key_dtype))
if table.key_dtype.is_integer != key_dtype.is_integer:
raise TypeError("Invalid key dtype, expected %s but got %s." %
("integer" if key_dtype.is_integer else "non-integer",
table.key_dtype))
if table.value_dtype != dtypes.int64:
raise TypeError("Invalid value dtype, expected %s but got %s." %
(dtypes.int64, table.value_dtype))
self._table = table
name = name or self._table.name
else:
if num_oov_buckets <= 0:
raise ValueError("oov_buckets must be > 0 if no table is supplied.")
key_dtype = dtypes.string if key_dtype is None else key_dtype
self._table = None
name = name or "hash_bucket"
if (not key_dtype.is_integer) and (dtypes.string != key_dtype):
raise TypeError(
"Invalid key_dtype, expected integer or string, got %s." % key_dtype)
self._num_oov_buckets = num_oov_buckets
if not isinstance(hasher_spec, HasherSpec):
raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
hasher_spec)
self._hasher_spec = hasher_spec
super(IdTableWithHashBuckets, self).__init__(dtypes.string, dtypes.int64,
super(IdTableWithHashBuckets, self).__init__(key_dtype, dtypes.int64,
name.split("/")[-1])
@property
@@ -748,24 +783,25 @@ class IdTableWithHashBuckets(LookupInterface):
if keys.dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
string_values = keys
values = keys
if isinstance(keys, sparse_tensor.SparseTensor):
string_values = keys.values
values = keys.values
if self._table and (self._table.key_dtype.base_dtype == dtypes.int64):
values = math_ops.to_int64(values)
if self._num_oov_buckets == 0:
ids = self._table.lookup(string_values, name=name)
ids = self._table.lookup(values, name=name)
else:
# TODO(yleon): Consider moving this functionality to its own kernel.
with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
self._hasher_spec)
buckets = str_to_hash_bucket(
string_values,
_as_string(values),
num_buckets=self._num_oov_buckets,
name="hash_bucket")
if self._table:
ids = self._table.lookup(string_values)
ids = self._table.lookup(values)
buckets = math_ops.add(buckets, self._table.size())
is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
@@ -776,12 +812,25 @@ class IdTableWithHashBuckets(LookupInterface):
return ids
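Putting the wrapper changes together, a hedged construction of an integer-keyed `IdTableWithHashBuckets` (TF 1.x contrib paths assumed):

```python
# Hedged sketch: an int64-keyed table wrapped with OOV buckets
# (assumes TF 1.x contrib paths).
import tensorflow as tf
from tensorflow.contrib import lookup

keys = tf.constant([42, 1, -1000], dtype=tf.int64)
values = tf.constant([0, 1, 2], dtype=tf.int64)
table = lookup.HashTable(
    lookup.KeyValueTensorInitializer(keys, values), default_value=-1)
oov_table = lookup.IdTableWithHashBuckets(
    table, num_oov_buckets=3, key_dtype=tf.int64)

# In-vocabulary keys map to their table ids; unknown keys hash into one
# of the 3 buckets appended after the table (ids 3..5 here).
ids = oov_table.lookup(tf.constant([42, 7], dtype=tf.int64))
```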
@deprecated("2017-04-10", "Use `index_table_from_file`.")
def string_to_index_table_from_file(vocabulary_file=None,
num_oov_buckets=0,
vocab_size=None,
default_value=-1,
hasher_spec=FastHashSpec,
name=None):
return index_table_from_file(
vocabulary_file, num_oov_buckets, vocab_size, default_value, hasher_spec,
key_dtype=dtypes.string, name=name)
def index_table_from_file(vocabulary_file=None,
num_oov_buckets=0,
vocab_size=None,
default_value=-1,
hasher_spec=FastHashSpec,
key_dtype=dtypes.string,
name=None):
"""Returns a lookup table that converts a string tensor into int64 IDs.
This operation constructs a lookup table to convert a tensor of strings into
@@ -809,7 +858,7 @@ def string_to_index_table_from_file(vocabulary_file=None,
```python
features = tf.constant(["emerson", "lake", "and", "palmer"])
table = tf.contrib.lookup.string_to_index_table_from_file(
table = tf.contrib.lookup.index_table_from_file(
vocabulary_file="test.txt", num_oov_buckets=1)
ids = table.lookup(features)
...
@@ -826,6 +875,7 @@ def string_to_index_table_from_file(vocabulary_file=None,
Defaults to -1.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignment of out-of-vocabulary buckets.
key_dtype: The `key` data type.
name: A name for this op (optional).
Returns:
@@ -843,6 +893,8 @@ def string_to_index_table_from_file(vocabulary_file=None,
% num_oov_buckets)
if vocab_size is not None and vocab_size < 1:
raise ValueError("vocab_size must be greater than 0, got %d." % vocab_size)
if (not key_dtype.is_integer) and (dtypes.string != key_dtype.base_dtype):
raise TypeError("Only integer and string keys are supported.")
with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
table = None
@@ -861,7 +913,9 @@ def string_to_index_table_from_file(vocabulary_file=None,
TextFileIndex.WHOLE_LINE,
TextFileIndex.LINE_NUMBER)
init = TextFileIdTableInitializer(
vocabulary_file, vocab_size=vocab_size, name="table_init")
vocabulary_file, vocab_size=vocab_size,
key_dtype=dtypes.int64 if key_dtype.is_integer else key_dtype,
name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
@@ -870,16 +924,32 @@ def string_to_index_table_from_tensor(mapping,
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
name=feat_to_id_scope)
name=feat_to_id_scope,
key_dtype=key_dtype)
return table
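With `key_dtype` plumbed through, an integer vocabulary file can back an index table directly. A hedged sketch, where `vocab.txt` is a hypothetical file holding one integer key per line (e.g. `42`, `1`, `-1000`):

```python
# Hedged sketch ("vocab.txt" is hypothetical: one integer key per line;
# assumes TF 1.x contrib paths).
import tensorflow as tf
from tensorflow.contrib import lookup

table = lookup.index_table_from_file(
    vocabulary_file="vocab.txt", num_oov_buckets=1, key_dtype=tf.int64)
ids = table.lookup(tf.constant([42, 7], dtype=tf.int64))
```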
@deprecated("2017-04-10", "Use `index_table_from_tensor`.")
def string_to_index_table_from_tensor(mapping,
num_oov_buckets=0,
default_value=-1,
hasher_spec=FastHashSpec,
name=None):
with ops.name_scope(name, "string_to_index") as scope:
mapping = ops.convert_to_tensor(mapping)
if dtypes.string != mapping.dtype.base_dtype:
raise ValueError("string_to_index_table_from_tensor requires string.")
return index_table_from_tensor(
mapping, num_oov_buckets, default_value, hasher_spec, name=scope)
def index_table_from_tensor(mapping,
num_oov_buckets=0,
default_value=-1,
hasher_spec=FastHashSpec,
dtype=dtypes.string,
name=None):
"""Returns a lookup table that converts a string tensor into int64 IDs.
This operation constructs a lookup table to convert a tensor of strings into
@@ -902,7 +972,7 @@ def string_to_index_table_from_tensor(mapping,
```python
mapping_strings = tf.constant(["emerson", "lake", "palmer"])
table = tf.contrib.lookup.string_to_index_table_from_tensor(
table = tf.contrib.lookup.index_table_from_tensor(
mapping=mapping_strings, num_oov_buckets=1, default_value=-1)
features = tf.constant(["emerson", "lake", "and", "palmer"])
ids = table.lookup(features)
@@ -913,20 +983,22 @@ def string_to_index_table_from_tensor(mapping,
```
Args:
mapping: A 1-D string `Tensor` that specifies the mapping of strings to
indices.
mapping: A 1-D `Tensor` that specifies the mapping of keys to indices. The
type of this object must be castable to `dtype`.
num_oov_buckets: The number of out-of-vocabulary buckets.
default_value: The value to use for out-of-vocabulary feature values.
Defaults to -1.
hasher_spec: A `HasherSpec` to specify the hash function to use for
assignation of out-of-vocabulary buckets.
assignment of out-of-vocabulary buckets.
dtype: The type of values passed to `lookup`. Only string and integers are
supported.
name: A name for this op (optional).
Returns:
The lookup table to map a string `Tensor` to index `int64` `Tensor`.
The lookup table to map an input `Tensor` to index `int64` `Tensor`.
Raises:
ValueError: `mapping` is invalid.
ValueError: If `mapping` is invalid.
ValueError: If `num_oov_buckets` is negative.
"""
if mapping is None:
@@ -936,15 +1008,25 @@ def string_to_index_table_from_tensor(mapping,
raise ValueError("num_oov_buckets must be greater or equal than 0, got %d."
% num_oov_buckets)
if (not dtype.is_integer) and (dtypes.string != dtype.base_dtype):
raise TypeError("Only integer and string keys are supported.")
with ops.name_scope(name, "string_to_index") as feat_to_id_scope:
keys = ops.convert_to_tensor(mapping, dtypes.string)
keys = ops.convert_to_tensor(mapping)
if keys.dtype.is_integer != dtype.is_integer:
raise ValueError("Expected %s, got %s." % (
"integer" if dtype.is_integer else "non-integer", keys.dtype))
if (not dtype.is_integer) and (keys.dtype.base_dtype != dtype):
raise ValueError("Expected %s, got %s." % (dtype, keys.dtype))
num_elements = array_ops.size(keys)
values = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
values = math_ops.to_int64(math_ops.range(num_elements))
shared_name = ""
with ops.name_scope(None, "hash_table") as hash_table_scope:
table_keys = math_ops.to_int64(keys) if keys.dtype.is_integer else keys
init = KeyValueTensorInitializer(
keys, values, dtypes.string, dtypes.int64, name="table_init")
table_keys, values, table_keys.dtype.base_dtype, dtypes.int64,
name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
if num_oov_buckets:
@@ -952,14 +1034,15 @@ def string_to_index_table_from_tensor(mapping,
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
name=feat_to_id_scope)
name=feat_to_id_scope,
key_dtype=dtype)
return table
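End to end, a hedged sketch of `index_table_from_tensor` with an integer mapping, including the OOV behavior:

```python
# Hedged sketch: integer mapping plus OOV buckets (assumes TF 1.x).
import tensorflow as tf
from tensorflow.contrib import lookup

table = lookup.index_table_from_tensor(
    mapping=(42, 1, -1000), num_oov_buckets=2, dtype=tf.int64)
ids = table.lookup(tf.constant([1, -1000, 7], dtype=tf.int64))

with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    # 42 -> 0, 1 -> 1, -1000 -> 2; the unknown key 7 hashes to id 3 or 4.
    print(sess.run(ids))
```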
@deprecated(
"2017-01-07", "This op will be removed after the deprecation date. "
"Please switch to string_to_index_table_from_tensor and call the lookup "
"Please switch to index_table_from_tensor and call the lookup "
"method of the returned table.")
def string_to_index(tensor, mapping, default_value=-1, name=None):
"""Maps `tensor` of strings into `int64` indices based on `mapping`.
@@ -1002,7 +1085,7 @@ def string_to_index(tensor, mapping, default_value=-1, name=None):
The mapped indices. It has the same shape and tensor type (dense or sparse)
as `tensor`.
"""
table = string_to_index_table_from_tensor(
table = index_table_from_tensor(
mapping=mapping, default_value=default_value, name=name)
return table.lookup(tensor)
@@ -1135,7 +1218,7 @@ def index_to_string_table_from_tensor(mapping, default_value="UNK", name=None):
with ops.name_scope(name, "index_to_string") as scope:
values = ops.convert_to_tensor(mapping, dtypes.string)
num_elements = array_ops.size(values)
keys = math_ops.cast(math_ops.range(num_elements), dtypes.int64)
keys = math_ops.to_int64(math_ops.range(num_elements))
shared_name = ""
init = KeyValueTensorInitializer(
@@ -1306,7 +1389,7 @@ class MutableHashTable(LookupInterface):
(self._key_dtype, keys.dtype))
with ops.name_scope(name, "%s_lookup_table_find" % self._name,
[self._table_ref, keys]) as name:
(self._table_ref, keys, self._default_value)) as name:
# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(self._table_ref,
keys,
@@ -205,12 +205,12 @@ class TextFileLineIterator
return;
}
}
status_ = SetValue(line, tokens, key_index_, key_.dtype(), &key_);
status_ = SetValue(line, tokens, key_index_, &key_);
if (!status_.ok()) {
valid_ = false;
return;
}
status_ = SetValue(line, tokens, value_index_, value_.dtype(), &value_);
status_ = SetValue(line, tokens, value_index_, &value_);
if (!status_.ok()) {
valid_ = false;
return;
@@ -247,17 +247,14 @@ class TextFileLineIterator
// Set the corresponding value from line or tokens based on 'index' into the
// tensor 't'. The value is transformed to the tensor's data type.
Status SetValue(const string& line, const std::vector<string>& tokens,
int64 index, DataType dtype, Tensor* tensor) {
int64 index, Tensor* tensor) {
if (index == kLineNumber) {
tensor->flat<int64>()(0) = next_id_;
return Status::OK();
}
if (index == kWholeLine) {
tensor->flat<string>()(0) = line;
return Status::OK();
}
const string& token = tokens[index];
switch (tensor->dtype()) {
const string& token = (index == kWholeLine) ? line : tokens[index];
const DataType& dtype = tensor->dtype();
switch (dtype) {
case DT_INT32: {
int32 value;
if (!strings::safe_strto32(token.c_str(), &value)) {
@@ -317,26 +314,28 @@ Status InitializeTableFromTextFile(const string& filename, int64 vocab_size,
"Key index for line number requires table key dtype of int64, got ",
table->key_dtype());
}
if (key_index == kWholeLine && table->key_dtype() != DT_STRING) {
const DataType& key_dtype = table->key_dtype();
const DataType& value_dtype = table->value_dtype();
if (key_index == kWholeLine && !DataTypeIsInteger(key_dtype) &&
key_dtype != DT_STRING) {
return errors::InvalidArgument(
"Key index for whole line requires table key dtype of string, got ",
"Key index for whole line requires string or integer table key, got ",
table->key_dtype());
}
if (value_index == kLineNumber && table->value_dtype() != DT_INT64) {
if (value_index == kLineNumber && value_dtype != DT_INT64) {
return errors::InvalidArgument(
"Value index for line number requires table value dtype of int64, got ",
table->value_dtype());
}
if (value_index == kWholeLine && table->value_dtype() != DT_STRING) {
if (value_index == kWholeLine && value_dtype != DT_STRING) {
return errors::InvalidArgument(
"Value index for whole line requires table value dtype of string, got ",
table->value_dtype());
}
TextFileLineIterator iter;
TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter,
table->key_dtype(), key_index,
table->value_dtype(), value_index, env));
TF_RETURN_IF_ERROR(iter.Init(filename, vocab_size, delimiter, key_dtype,
key_index, value_dtype, value_index, env));
// For initialization from files, ignore if the table is already
// initialized. The table shared name should contain the filename to
// avoid trying to initialize the same table from the same file at the same
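On the kernel side, the relaxed whole-line check means a text-file-backed table can now parse integer keys from each line. A hedged sketch through the Python `TextFileInitializer` (again with a hypothetical `vocab.txt`):

```python
# Hedged sketch: whole-line integer keys from a text file
# ("vocab.txt" is hypothetical; assumes TF 1.x contrib paths).
import tensorflow as tf
from tensorflow.contrib import lookup

init = lookup.TextFileInitializer(
    "vocab.txt",
    key_dtype=tf.int64, key_index=lookup.TextFileIndex.WHOLE_LINE,
    value_dtype=tf.int64, value_index=lookup.TextFileIndex.LINE_NUMBER)
table = lookup.HashTable(init, default_value=-1)
```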