提交 fdc6752c 编写于 作者: Y Yutaka Leon 提交者: TensorFlower Gardener

Add tf.contrib.lookup.string_to_index and tf.contrib.lookup.index_to_string to...

Add tf.contrib.lookup.string_to_index and tf.contrib.lookup.index_to_string to map strings to IDs and vice versa.
Change: 117303981
上级 e06a4a2e
......@@ -14,6 +14,8 @@
# ==============================================================================
"""Ops for lookup operations.
@@string_to_index
@@index_to_string
@@LookupInterface
@@InitializableLookupTableBase
@@HashTable
......
......@@ -21,7 +21,9 @@ from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
class LookupInterface(object):
......@@ -301,3 +303,116 @@ class KeyValueTensorInitializer(TableInitializerBase):
# pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
def string_to_index(tensor, mapping, default_value=-1, name=None):
  """Maps `tensor` of strings into `int64` indices based on `mapping`.

  This operation converts `tensor` of strings into `int64` indices.
  The mapping is initialized from a string `mapping` tensor where each element
  is a key and the corresponding index within the tensor is the value.

  Any entry in the input which does not have a corresponding entry in `mapping`
  (an out-of-vocabulary entry) is assigned the `default_value`.

  Elements in `mapping` cannot be duplicated, otherwise the initialization
  will throw a FailedPreconditionError.

  The underlying table must be initialized by calling
  `tf.initialize_all_tables().run()` once.

  For example:

  ```python
  mapping_strings = tf.constant(["emerson", "lake", "palmer"])
  feats = tf.constant(["emerson", "lake", "and", "palmer"])
  ids = tf.contrib.lookup.string_to_index(
      feats, mapping=mapping_strings, default_value=-1)
  ...
  tf.initialize_all_tables().run()

  ids.eval()  ==> [0, 1, -1, 2]
  ```

  Args:
    tensor: A 1-D input `Tensor` with the strings to map to indices.
    mapping: A 1-D string `Tensor` that specifies the mapping of strings to
      indices.
    default_value: The `int64` value to use for out-of-vocabulary strings.
      Defaults to -1.
    name: A name for this op (optional).

  Returns:
    The mapped indices. It has the same shape and tensor type (dense or sparse)
    as `tensor`.
  """
  with ops.op_scope([tensor], name, "string_to_index") as scope:
    # NOTE(review): an empty shared_name presumably keeps the table private to
    # this call -- confirm against HashTable.
    shared_name = ""
    keys = ops.convert_to_tensor(mapping, dtypes.string)
    # The value stored for each key is its position in `mapping`: 0..size-1.
    vocab_size = array_ops.size(keys)
    values = math_ops.cast(math_ops.range(vocab_size), dtypes.int64)
    init = KeyValueTensorInitializer(keys,
                                     values,
                                     dtypes.string,
                                     dtypes.int64,
                                     name="table_init")
    t = HashTable(init,
                  default_value,
                  shared_name=shared_name,
                  name="hash_table")
    # The lookup op takes the scope name so the result is named after this op.
    return t.lookup(tensor, name=scope)
def index_to_string(tensor, mapping, default_value="UNK", name=None):
  """Maps `tensor` of indices into string values based on `mapping`.

  This operation converts `int64` indices into string values. The mapping is
  initialized from a string `mapping` tensor where each element is a value and
  the corresponding index within the tensor is the key.

  Any input which does not have a corresponding index in `mapping`
  (an out-of-vocabulary entry) is assigned the `default_value`.

  The underlying table must be initialized by calling
  `tf.initialize_all_tables().run()` once.

  For example:

  ```python
  mapping_string = tf.constant(["emerson", "lake", "palmer"])
  indices = tf.constant([1, 5], tf.int64)
  values = tf.contrib.lookup.index_to_string(
      indices, mapping=mapping_string, default_value="UNKNOWN")
  ...
  tf.initialize_all_tables().run()

  values.eval() ==> ["lake", "UNKNOWN"]
  ```

  Args:
    tensor: A `int64` `Tensor` with the indices to map to strings.
    mapping: A 1-D string `Tensor` that specifies the strings to map from
      indices.
    default_value: The string value to use for out-of-vocabulary indices.
      Defaults to "UNK".
    name: A name for this op (optional).

  Returns:
    The string values associated with the indices. The resultant dense
    feature value tensor has the same shape as the corresponding `tensor`.
  """
  with ops.op_scope([tensor], name, "index_to_string") as scope:
    # NOTE(review): an empty shared_name presumably keeps the table private to
    # this call -- confirm against HashTable.
    shared_name = ""
    values = ops.convert_to_tensor(mapping, dtypes.string)
    # The key for each string is its position in `mapping`: 0..size-1.
    vocab_size = array_ops.size(values)
    keys = math_ops.cast(math_ops.range(vocab_size), dtypes.int64)
    init = KeyValueTensorInitializer(keys,
                                     values,
                                     dtypes.int64,
                                     dtypes.string,
                                     name="table_init")
    t = HashTable(init,
                  default_value,
                  shared_name=shared_name,
                  name="hash_table")
    # The lookup op takes the scope name so the result is named after this op.
    return t.lookup(tensor, name=scope)
......@@ -235,5 +235,81 @@ class HashTableOpTest(tf.test.TestCase):
values), default_val)
class StringToIndexTest(tf.test.TestCase):
  """Exercises `tf.contrib.lookup.string_to_index`."""

  def test_string_to_index(self):
    """Known strings map to their position; unknown ones map to -1."""
    with self.test_session():
      vocab = tf.constant(["brain", "salad", "surgery"])
      tokens = tf.constant(["salad", "surgery", "tarkus"])
      ids = tf.contrib.lookup.string_to_index(tokens, mapping=vocab)

      # Evaluating before the tables are initialized must fail.
      with self.assertRaises(tf.OpError):
        ids.eval()

      tf.initialize_all_tables().run()
      self.assertAllEqual([1, 2, -1], ids.eval())

  def test_duplicate_entries(self):
    """A mapping with repeated strings makes table initialization fail."""
    with self.test_session():
      vocab = tf.constant(["hello", "hello"])
      tokens = tf.constant(["hello", "hola"])
      # Called only for its side effect of adding the table to the graph.
      tf.contrib.lookup.string_to_index(tokens, mapping=vocab)

      init_op = tf.initialize_all_tables()
      with self.assertRaises(tf.OpError):
        init_op.run()

  def test_string_to_index_with_default_value(self):
    """Unknown strings map to the caller-supplied default value."""
    oov_id = -42
    with self.test_session():
      vocab = tf.constant(["brain", "salad", "surgery"])
      tokens = tf.constant(["salad", "surgery", "tarkus"])
      ids = tf.contrib.lookup.string_to_index(tokens,
                                              mapping=vocab,
                                              default_value=oov_id)

      # Evaluating before the tables are initialized must fail.
      with self.assertRaises(tf.OpError):
        ids.eval()

      tf.initialize_all_tables().run()
      self.assertAllEqual([1, 2, oov_id], ids.eval())
class IndexToStringTest(tf.test.TestCase):
  """Exercises `tf.contrib.lookup.index_to_string`.

  Expected values are written as byte strings (`b"..."`). Evaluated string
  tensors come back as bytes, so plain `str` literals would only compare equal
  on Python 2.
  """

  def test_index_to_string(self):
    """In-range indices map to their string; others map to "UNK"."""
    with self.test_session():
      mapping_strings = tf.constant(["brain", "salad", "surgery"])
      indices = tf.constant([0, 1, 2, 3], tf.int64)
      feats = tf.contrib.lookup.index_to_string(indices,
                                                mapping=mapping_strings)

      # Evaluating before the tables are initialized must fail.
      self.assertRaises(tf.OpError, feats.eval)
      tf.initialize_all_tables().run()

      self.assertAllEqual((b"brain", b"salad", b"surgery", b"UNK"),
                          feats.eval())

  def test_duplicate_entries(self):
    """Repeated strings are fine here because the keys are the positions."""
    with self.test_session():
      mapping_strings = tf.constant(["hello", "hello"])
      indices = tf.constant([0, 1, 4], tf.int64)
      feats = tf.contrib.lookup.index_to_string(indices,
                                                mapping=mapping_strings)
      tf.initialize_all_tables().run()
      self.assertAllEqual((b"hello", b"hello", b"UNK"), feats.eval())

      # Running the initializer a second time is expected to raise.
      # NOTE(review): presumably because the table is already initialized --
      # confirm against the table initializer op.
      self.assertRaises(tf.OpError, tf.initialize_all_tables().run)

  def test_index_to_string_with_default_value(self):
    """Out-of-range indices map to the caller-supplied default value."""
    # Bytes literal so the same object serves as both the argument and the
    # expected result on Python 2 and Python 3.
    default_value = b"NONE"
    with self.test_session():
      mapping_strings = tf.constant(["brain", "salad", "surgery"])
      indices = tf.constant([1, 2, 4], tf.int64)
      feats = tf.contrib.lookup.index_to_string(indices,
                                                mapping=mapping_strings,
                                                default_value=default_value)

      # Evaluating before the tables are initialized must fail.
      self.assertRaises(tf.OpError, feats.eval)

      tf.initialize_all_tables().run()
      self.assertAllEqual((b"salad", b"surgery", default_value), feats.eval())
if __name__ == "__main__":
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册