Commit 5de7640b authored by Yutaka Leon, committed by TensorFlower Gardener

Add IdTableWithHashBuckets to assign out-of-vocabulary hash buckets for terms not covered by the wrapped lookup table.
Change: 141487850
Parent 34947dd1
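
As a rough illustration (not from the commit itself), the out-of-vocabulary scheme can be sketched in plain Python; `vocab` and `fingerprint` below are hypothetical stand-ins for the wrapped lookup table and the string-hashing op:

```python
# Hypothetical sketch of the OOV-bucket mapping; `vocab` and `fingerprint`
# stand in for the wrapped lookup table and the string-hashing op.
vocab = {"emerson": 0, "lake": 1, "palmer": 2}
num_oov_buckets = 5

def assign_id(term, fingerprint=hash):
  # In-vocabulary terms keep the id from the wrapped table.
  if term in vocab:
    return vocab[term]
  # Out-of-vocabulary terms get a deterministic bucket in
  # [vocab_size, vocab_size + num_oov_buckets).
  return len(vocab) + fingerprint(term) % num_oov_buckets

print(assign_id("lake"))  # 1
print(assign_id("king"))  # some id in [3, 8), stable for a fixed hash function
```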
@@ -18,14 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import functools
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.training.saver import BaseSaverBuilder
from tensorflow.python.util import compat
class LookupInterface(object):
@@ -71,7 +77,7 @@ class LookupInterface(object):
"""Looks up `keys` in a table, outputs the corresponding values."""
raise NotImplementedError
-  def _check_table_dtypes(self, key_dtype, value_dtype):
+  def check_table_dtypes(self, key_dtype, value_dtype):
"""Check that the given key_dtype and value_dtype matches the table dtypes.
Args:
@@ -293,14 +299,14 @@ class KeyValueTensorInitializer(TableInitializerBase):
TypeError: when the keys and values data types do not match the table
key and value data types.
"""
-    # pylint: disable=protected-access
-    table._check_table_dtypes(self._keys.dtype, self._values.dtype)
+    table.check_table_dtypes(self._keys.dtype, self._values.dtype)
with ops.name_scope(self._name, values=[table]) as scope:
# pylint: disable=protected-access
init_op = gen_data_flow_ops._initialize_table(table.table_ref,
self._keys,
self._values,
name=scope)
-    # pylint: enable=protected-access
+      # pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
@@ -453,12 +459,12 @@ class TextFileInitializer(TableInitializerBase):
TypeError: when the keys and values data types do not match the table
key and value data types.
"""
-    # pylint: disable=protected-access
-    table._check_table_dtypes(self.key_dtype, self.value_dtype)
+    table.check_table_dtypes(self.key_dtype, self.value_dtype)
with ops.name_scope(self._name, "text_file_init", [table]) as scope:
filename = ops.convert_to_tensor(self._filename,
dtypes.string,
name="asset_filepath")
# pylint: disable=protected-access
init_op = gen_data_flow_ops._initialize_table_from_text_file(
table.table_ref,
filename,
@@ -467,7 +473,7 @@ class TextFileInitializer(TableInitializerBase):
-1 if self._vocab_size is None else self._vocab_size,
self._delimiter,
name=scope)
-    # pylint: enable=protected-access
+      # pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
ops.add_to_collection(ops.GraphKeys.ASSET_FILEPATHS, filename)
return init_op
@@ -573,6 +579,202 @@ class TextFileIdTableInitializer(TextFileInitializer):
name=name)
class HasherSpec(collections.namedtuple("HasherSpec", ["hasher", "key"])):
"""A structure for the spec of the hashing function to use for hash buckets.
`hasher` is the name of the hashing function to use (e.g. "fasthash",
"stronghash").
`key` is optional and specifies the key to use for the hash function if
supported; currently it is only used by the strong hash.
Fields:
hasher: The hasher name to use.
key: The key to be used by the hashing function, if required.
"""
__slots__ = ()
FastHashSpec = HasherSpec("fasthash", None)
class StrongHashSpec(HasherSpec):
"""A structure to specify a key of the strong keyed hash spec.
The strong hash requires a `key`, which is a list of 2 unsigned integer
numbers. These should be non-zero; random numbers generated from random.org
would be a fine choice.
Fields:
key: The key to be used by the keyed hashing function.
"""
__slots__ = ()
def __new__(cls, key):
if len(key) != 2:
raise ValueError("key must have size 2, got %s." % len(key))
if not isinstance(key[0], compat.integral_types) or not isinstance(
key[1], compat.integral_types):
raise TypeError("Invalid key %s. Must be unsigned integer values." % key)
return super(StrongHashSpec, cls).__new__(cls, "stronghash", key)
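# A minimal usage sketch of the specs above; illustrative only, the variable
# names here are not part of this module. "legacy" selects the plain
# string_to_hash_bucket op in _get_string_to_hash_bucket_fn below.
fast_spec = FastHashSpec  # equivalent to HasherSpec("fasthash", None)
strong_spec = StrongHashSpec((12345, 67890))  # key: exactly two unsigned ints
legacy_spec = HasherSpec("legacy", None)
assert strong_spec == ("stronghash", (12345, 67890))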
class IdTableWithHashBuckets(LookupInterface):
"""String to Id table wrapper that assigns out-of-vocabulary keys to buckets.
For example, if an instance of `IdTableWithHashBuckets` is initialized with a
string-to-id table that maps:
- emerson -> 0
- lake -> 1
- palmer -> 2
The `IdTableWithHashBuckets` object will perform the following mapping:
- emerson -> 0
- lake -> 1
- palmer -> 2
- <other term> -> bucket id between 3 and 3 + num_oov_buckets - 1, calculated by:
  hash(<term>) % num_oov_buckets + vocab_size
If input_tensor is ["emerson", "lake", "palmer", "king", "crimson"],
the lookup result is [0, 1, 2, 4, 7].
If `table` is None, only out-of-vocabulary buckets are used.
Example usage:
```python
num_oov_buckets = 3
input_tensor = tf.constant(["emerson", "lake", "palmer", "king", "crimson"])
table = tf.IdTableWithHashBuckets(
tf.HashTable(tf.TextFileIdTableInitializer(filename), default_value),
num_oov_buckets)
out = table.lookup(input_tensor)
table.init.run()
print(out.eval())
```
The hash function used for generating out-of-vocabulary bucket IDs is
specified by `hasher_spec`.
"""
def __init__(self,
table,
num_oov_buckets,
hasher_spec=FastHashSpec,
name=None):
"""Construct a `IdTableWithHashBuckets` object.
Args:
table: Table that maps string to ids.
num_oov_buckets: Number of buckets to use for out-of-vocabulary keys.
hasher_spec: A `HasherSpec` to specify the hash function to use for
  assignment of out-of-vocabulary buckets (optional).
name: A name for the operation (optional).
Raises:
ValueError: when `table` is None and `num_oov_buckets` is not positive.
TypeError: when `hasher_spec` is invalid.
"""
# If a name ends with a '/' it is a "name scope", remove all trailing '/'
# characters to use as table name.
if name:
name = name.rstrip("/")
if table:
table.check_table_dtypes(dtypes.string, dtypes.int64)
self._table = table
name = name or self._table.name
else:
if num_oov_buckets <= 0:
raise ValueError("oov_buckets must be > 0 if no table is supplied.")
self._table = None
name = name or "hash_bucket"
self._num_oov_buckets = num_oov_buckets
if not isinstance(hasher_spec, HasherSpec):
raise TypeError("hasher_spec must be of type HasherSpec, got %s" %
hasher_spec)
self._hasher_spec = hasher_spec
super(IdTableWithHashBuckets, self).__init__(dtypes.string, dtypes.int64,
name.split("/")[-1])
@property
def init(self):
"""The table initialization op."""
if self._table:
return self._table.init
with ops.name_scope(None, "init"):
return control_flow_ops.no_op()
def size(self, name=None):
"""Compute the number of elements in this table."""
with ops.name_scope(name, "%s_Size" % self.name) as scope:
if self._table:
tsize = self._table.size(scope)
else:
tsize = ops.convert_to_tensor(0, dtype=dtypes.int64)
return tsize + self._num_oov_buckets
def _get_string_to_hash_bucket_fn(self, hasher_spec):
"""Returns the string_to_hash_bucket op to use based on `hasher_spec`."""
if not isinstance(hasher_spec, HasherSpec):
raise TypeError("hasher_spec must be of type HasherSpec %s" % hasher_spec)
if hasher_spec.hasher == "fasthash":
return string_ops.string_to_hash_bucket_fast
if hasher_spec.hasher == "legacy":
return string_ops.string_to_hash_bucket
if hasher_spec.hasher == "stronghash":
return functools.partial(
string_ops.string_to_hash_bucket_strong, key=hasher_spec.key)
raise ValueError("Unknown hasher %s" % hasher_spec.hasher)
def lookup(self, keys, name=None):
"""Looks up `keys` in the table, outputs the corresponding values.
It assigns out-of-vocabulary keys to buckets based on their hashes.
Args:
keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
name: Optional name for the op.
Returns:
A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
Raises:
TypeError: when `keys` doesn't match the table key data type.
"""
if keys.dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
string_values = keys
if isinstance(keys, sparse_tensor.SparseTensor):
string_values = keys.values
if self._num_oov_buckets == 0:
ids = self._table.lookup(string_values, name=name)
else:
# TODO(yleon): Consider moving this functionality to its own kernel.
with ops.name_scope(name, "%s_Lookup" % self.name) as scope:
str_to_hash_bucket = self._get_string_to_hash_bucket_fn(
self._hasher_spec)
buckets = str_to_hash_bucket(
string_values,
num_buckets=self._num_oov_buckets,
name="hash_bucket")
if self._table:
ids = self._table.lookup(string_values)
buckets = math_ops.add(buckets, self._table.size())
is_id_non_default = math_ops.not_equal(ids, self._table.default_value)
ids = array_ops.where(is_id_non_default, ids, buckets, name=scope)
else:
ids = buckets
if isinstance(keys, sparse_tensor.SparseTensor):
return sparse_tensor.SparseTensor(keys.indices, ids, keys.dense_shape)
return ids
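# Illustrative model of the lookup fallback above, with plain Python dicts;
# `_model_lookup` is hypothetical and not part of this module. Keys found in
# the wrapped table keep their id; keys that come back as default_value are
# reassigned to a hash bucket offset by the table size.
def _model_lookup(keys, table, default_value, num_oov_buckets, bucket_fn):
  ids = [table.get(k, default_value) for k in keys]
  buckets = [bucket_fn(k) % num_oov_buckets + len(table) for k in keys]
  return [i if i != default_value else b for i, b in zip(ids, buckets)]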
def string_to_index(tensor, mapping, default_value=-1, name=None):
"""Maps `tensor` of strings into `int64` indices based on `mapping`.
@@ -829,7 +1031,7 @@ class MutableHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
-    self._check_table_dtypes(keys.dtype, values.dtype)
+    self.check_table_dtypes(keys.dtype, values.dtype)
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
# pylint: disable=protected-access
@@ -1029,7 +1231,7 @@ class MutableDenseHashTable(LookupInterface):
TypeError: when `keys` or `values` doesn't match the table data
types.
"""
-    self._check_table_dtypes(keys.dtype, values.dtype)
+    self.check_table_dtypes(keys.dtype, values.dtype)
with ops.name_scope(name, "%s_lookup_table_insert" % self._name,
[self._table_ref, keys, values]) as name:
# pylint: disable=protected-access
@@ -23,6 +23,8 @@ import numpy as np
import six
import tensorflow as tf
from tensorflow.python.framework import test_util
class HashTableOpTest(tf.test.TestCase):
@@ -1495,5 +1497,249 @@ class InitializeTableFromFileOpTest(tf.test.TestCase):
self.assertEquals(vocab_size, table.size().eval())
class IdTableWithHashBucketsTest(tf.test.TestCase):
def _createVocabFile(self, basename):
vocabulary_file = os.path.join(self.get_temp_dir(), basename)
with open(vocabulary_file, "w") as f:
f.write("\n".join(["brain", "salad", "surgery"]) + "\n")
return vocabulary_file
def testIdTableWithHashBucketsInit(self):
vocab_file = self._createVocabFile("feat_to_id_3.txt")
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
table = tf.contrib.lookup.IdTableWithHashBuckets(
tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value),
oov_buckets)
table.init.run()
input_string = tf.constant(["brain", "salad", "surgery", "UNK"])
out = table.lookup(input_string)
self.assertAllEqual([0, 1, 2, 3], out.eval())
self.assertEquals(vocab_size + oov_buckets, table.size().eval())
def testIdTableWithOnlyHashBucket(self):
with self.test_session():
oov_buckets = 5
# Set up a table that only uses hash buckets: for each input value it
# returns an id calculated by fingerprint("input") mod oov_buckets.
table = tf.contrib.lookup.IdTableWithHashBuckets(None, oov_buckets)
table.init.run()
input_string = tf.constant(["brain", "salad", "surgery"])
out = table.lookup(input_string)
self.assertAllEqual(
[
3, # fingerprint("brain") mod 5.
1, # fingerprint("salad") mod 5.
4  # fingerprint("surgery") mod 5.
],
out.eval())
self.assertEquals(oov_buckets, table.size().eval())
def testIdTableWithHashBucketsWithMultipleInitializers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.test_session() as sess:
default_value = -1
vocab_size = 3
oov_buckets = 3
vocab_table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value)
table1 = tf.contrib.lookup.IdTableWithHashBuckets(
vocab_table,
oov_buckets,
hasher_spec=tf.contrib.lookup.FastHashSpec,
name="table1")
table2 = tf.contrib.lookup.IdTableWithHashBuckets(
vocab_table,
oov_buckets,
hasher_spec=tf.contrib.lookup.StrongHashSpec((1, 2)),
name="table2")
tf.initialize_all_tables().run()
input_string = tf.constant(["fruit", "brain", "salad", "surgery", "UNK"])
out1 = table1.lookup(input_string)
out2 = table2.lookup(input_string)
out1, out2 = sess.run([out1, out2])
self.assertAllEqual([5, 0, 1, 2, 5], out1)
self.assertAllEqual([5, 0, 1, 2, 3], out2)
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
test_util.assert_ops_in_graph({
"table1_Lookup/hash_bucket": "StringToHashBucketFast",
"table2_Lookup/hash_bucket": "StringToHashBucketStrong",
}, sess.graph)
def testIdTableWithHashBucketsInitializationAcrossSessions(self):
vocab_file = self._createVocabFile("feat_to_id_5.txt")
shared_name = "across-sessions"
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
table1 = tf.contrib.lookup.IdTableWithHashBuckets(
tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value,
shared_name=shared_name),
oov_buckets)
table1.init.run()
input_string_1 = tf.constant(["brain", "salad", "surgery", "UNK"])
out1 = table1.lookup(input_string_1)
self.assertAllEqual([0, 1, 2, 3], out1.eval())
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
# Underlying lookup table already initialized in previous session.
# No need to call table2.init.run()
table2 = tf.contrib.lookup.IdTableWithHashBuckets(
tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value,
shared_name=shared_name),
oov_buckets)
input_string_2 = tf.constant(["fruit", "salad", "UNK"])
out2 = table2.lookup(input_string_2)
self.assertAllEqual([3, 1, 3], out2.eval())
self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
def testIdTableWithHashBucketsWithMultipleInitializersDifferentDefault(self):
vocab_file = self._createVocabFile("feat_to_id_6.txt")
with self.test_session() as sess:
default_value1 = -1
vocab_size = 3
oov_buckets = 0
table1 = tf.contrib.lookup.IdTableWithHashBuckets(
tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value1),
oov_buckets)
default_value2 = -2
table2 = tf.contrib.lookup.IdTableWithHashBuckets(
tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value2),
oov_buckets)
tf.initialize_all_tables().run()
input_string_1 = tf.constant(["brain", "salad", "surgery", "UNK"])
input_string_2 = tf.constant(["fruit", "salad", "UNK"])
out1 = table1.lookup(input_string_1)
out2 = table2.lookup(input_string_2)
out1, out2 = sess.run([out1, out2])
self.assertAllEqual([0, 1, 2, -1], out1)
self.assertAllEqual([-2, 1, -2], out2)
self.assertEquals(vocab_size + oov_buckets, table1.size().eval())
self.assertEquals(vocab_size + oov_buckets, table2.size().eval())
def testSparseTensor(self):
vocab_file = self._createVocabFile("feat_to_id_7.txt")
input_indices = [[0, 0], [0, 1], [2, 0], [2, 2], [3, 0]]
input_shape = [4, 4]
with self.test_session() as sess:
sp_features = tf.SparseTensor(
tf.constant(input_indices, tf.int64),
tf.constant(["brain", "salad", "brain", "surgery", "tarkus"],
tf.string), tf.constant(input_shape, tf.int64))
table = tf.contrib.lookup.IdTableWithHashBuckets(
tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=3),
-1),
1)
table.init.run()
sp_ids = table.lookup(sp_features)
self.assertAllEqual([5], sp_ids.values._shape_as_list())
sp_ids_ind, sp_ids_val, sp_ids_shape = sess.run(
[sp_ids.indices, sp_ids.values, sp_ids.dense_shape])
self.assertAllEqual(input_indices, sp_ids_ind)
self.assertAllEqual([0, 1, 0, 2, 3], sp_ids_val)
self.assertAllEqual(input_shape, sp_ids_shape)
def testIdTableWithHashBucketsWithInvalidHashers(self):
vocab_file = self._createVocabFile("feat_to_id_4.txt")
with self.test_session():
default_value = -1
vocab_size = 3
oov_buckets = 1
lookup_table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.TextFileIdTableInitializer(
vocab_file, vocab_size=vocab_size),
default_value)
with self.assertRaises(TypeError):
tf.contrib.lookup.IdTableWithHashBuckets(
lookup_table, oov_buckets, hasher_spec=1)
table = tf.contrib.lookup.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=tf.contrib.lookup.HasherSpec("my-awesome-hash", None))
input_string = tf.constant(["brain", "salad", "surgery", "UNK"])
with self.assertRaises(ValueError):
table.lookup(input_string)
with self.assertRaises(ValueError):
table = tf.contrib.lookup.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=tf.contrib.lookup.StrongHashSpec([]))
with self.assertRaises(ValueError):
table = tf.contrib.lookup.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=tf.contrib.lookup.StrongHashSpec([1, 2, 3]))
with self.assertRaises(TypeError):
table = tf.contrib.lookup.IdTableWithHashBuckets(
lookup_table,
oov_buckets,
hasher_spec=tf.contrib.lookup.StrongHashSpec([None, 2]))
if __name__ == "__main__":
tf.test.main()