提交 e06a4a2e 编写于 作者: Y Yutaka Leon 提交者: TensorFlower Gardener

- Add a generic immutable hash table op that is initialized once and used to...

- Add a generic immutable hash table op that is initialized once and used to map key tensors to values.

- Add a table initializer given the keys and values as tensors.

Example use case:

  keys = tf.constant([0, 1], tf.int64)
  values = tf.constant(["hello", "world"])
  default_value = "UNK"

  table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(keys, values), default_value)
  input_tensor = tf.constant([0, 2], tf.int64)
  out = table.lookup(input_tensor).
  tf.initialze_all_tables().run()  # or table.init.run()
  print out.eval()   # Returns ["hello", "UNK"]
Change: 117301920
上级 3ae663cc
......@@ -17,6 +17,7 @@ py_library(
"//tensorflow/contrib/distributions:distributions_py",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/testing:testing_py",
"//tensorflow/contrib/util:util_py",
],
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""contrib module containing volatile or experimental code."""
from __future__ import absolute_import
......@@ -24,5 +23,6 @@ from tensorflow.contrib import ctc
from tensorflow.contrib import distributions
from tensorflow.contrib import layers
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import testing
from tensorflow.contrib import util
# Description:
# contains parts of TensorFlow that are experimental or unstable and which are not supported.
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
package(default_visibility = ["//tensorflow:__subpackages__"])
py_library(
name = "lookup_py",
srcs = [
"__init__.py",
"lookup_ops.py",
],
srcs_version = "PY2AND3",
)
py_test(
name = "lookup_ops_test",
srcs = ["lookup_ops_test.py"],
srcs_version = "PY2AND3",
deps = [
":lookup_py",
"//third_party/py/numpy",
"//third_party/py/tensorflow",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
filegroup(
name = "all_files",
srcs = glob(
["**/*"],
exclude = [
"**/METADATA",
"**/OWNERS",
],
),
visibility = ["//tensorflow:__subpackages__"],
)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Ops for lookup operations.
@@LookupInterface
@@InitializableLookupTableBase
@@HashTable
@@TableInitializerBase
@@KeyValueTensorInitializer
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.lookup.lookup_ops import *
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Lookup table Operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import gen_data_flow_ops
class LookupInterface(object):
"""Represent a lookup table that persists across different steps."""
def __init__(self, key_dtype, value_dtype, name):
"""Construct a lookup table interface.
Args:
key_dtype: The table key type.
value_dtype: The table value type.
name: A name for the operation (optional).
"""
self._key_dtype = dtypes.as_dtype(key_dtype)
self._value_dtype = dtypes.as_dtype(value_dtype)
self._name = name
@property
def key_dtype(self):
"""The table key dtype."""
return self._key_dtype
@property
def value_dtype(self):
"""The table value dtype."""
return self._value_dtype
@property
def name(self):
"""The name of the table."""
return self._name
@property
def init(self):
"""The table initialization op."""
raise NotImplementedError
def size(self, name=None):
"""Compute the number of elements in this table."""
raise NotImplementedError
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values."""
raise NotImplementedError
def _check_table_dtypes(self, key_dtype, value_dtype):
"""Check that the given key_dtype and value_dtype matches the table dtypes.
Args:
key_dtype: The key data type to check.
value_dtype: The value data type to check.
Raises:
TypeError: when 'key_dtype' or 'value_dtype' doesn't match the table data
types.
"""
if key_dtype != self.key_dtype:
raise TypeError("Invalid key dtype, expected %s but got %s." %
(self.key_dtype, key_dtype))
if value_dtype != self.value_dtype:
raise TypeError("Invalid value dtype, expected %s but got %s." %
(self.value_dtype, value_dtype))
class InitializableLookupTableBase(LookupInterface):
"""Initializable lookup table interface.
An initializable lookup tables persist across different steps.
"""
def __init__(self, table_ref, default_value, initializer):
"""Construct a table object from a table reference.
If requires a table initializer object (subclass of `TableInitializerBase`).
It provides the table key and value types, as well as the op to initialize
the table. The caller is responsible to execute the initialization op.
Args:
table_ref: The table reference, i.e. the output of the lookup table ops.
default_value: The value to use if a key is missing in the table.
initializer: The table initializer to use.
"""
super(InitializableLookupTableBase, self).__init__(
initializer.key_dtype, initializer.value_dtype,
table_ref.op.name.split("/")[-1])
self._table_ref = table_ref
self._default_value = ops.convert_to_tensor(default_value,
dtype=self._value_dtype)
self._default_value.get_shape().merge_with(tensor_shape.scalar())
self._init = initializer.initialize(self)
@property
def table_ref(self):
"""Get the underlying table reference."""
return self._table_ref
@property
def default_value(self):
"""The default value of the table."""
return self._default_value
@property
def init(self):
"""The table initialization op."""
return self._init
def size(self, name=None):
"""Compute the number of elements in this table.
Args:
name: A name for the operation (optional).
Returns:
A scalar tensor containing the number of elements in this table.
"""
if name is None:
name = "%s_Size" % self._name
# pylint: disable=protected-access
return gen_data_flow_ops._lookup_table_size(self._table_ref, name=name)
# pylint: enable=protected-access
def lookup(self, keys, name=None):
"""Looks up `keys` in a table, outputs the corresponding values.
The `default_value` is use for keys not present in the table.
Args:
keys: Keys to look up. May be either a `SparseTensor` or dense `Tensor`.
name: A name for the operation (optional).
Returns:
A `SparseTensor` if keys are sparse, otherwise a dense `Tensor`.
Raises:
TypeError: when `keys` or `default_value` doesn't match the table data
types.
"""
if name is None:
name = "%s_lookup_table_find" % self._name
key_tensor = keys
if isinstance(keys, ops.SparseTensor):
key_tensor = keys.values
if keys.dtype != self._key_dtype:
raise TypeError("Signature mismatch. Keys must be dtype %s, got %s." %
(self._key_dtype, keys.dtype))
# pylint: disable=protected-access
values = gen_data_flow_ops._lookup_table_find(self._table_ref,
key_tensor,
self._default_value,
name=name)
# pylint: enable=protected-access
if isinstance(keys, ops.SparseTensor):
return ops.SparseTensor(keys.indices, values, keys.shape)
else:
return values
class HashTable(InitializableLookupTableBase):
"""A generic hash table implementation.
Example usage:
```python
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values), -1)
out = table.lookup(input_tensor).
table.init.run()
print out.eval()
```
"""
def __init__(self, initializer, default_value, shared_name=None, name=None):
"""Creates a non-initialized `HashTable` object.
Creates a table, the type of its keys and values are specified by the
initializer.
Before using the table you will have to initialize it. After initialization
the table will be immutable.
Args:
initializer: The table initializer to use.
default_value: The value to use if a key is missing in the table.
shared_name: If non-empty, this table will be shared under
the given name across multiple sessions.
name: A name for the operation (optional).
Returns:
A `HashTable` object.
"""
with ops.op_scope([initializer], name, "hash_table"):
# pylint: disable=protected-access
table_ref = gen_data_flow_ops._hash_table(
shared_name=shared_name,
key_dtype=initializer.key_dtype,
value_dtype=initializer.value_dtype,
name=name)
# pylint: enable=protected-access
super(HashTable, self).__init__(table_ref, default_value, initializer)
class TableInitializerBase(object):
"""Base class for lookup table initializers."""
def __init__(self, key_dtype, value_dtype):
"""Construct a table initializer object.
Args:
key_dtype: Type of the table keys.
value_dtype: Type of the table values.
"""
self._key_dtype = dtypes.as_dtype(key_dtype)
self._value_dtype = dtypes.as_dtype(value_dtype)
@property
def key_dtype(self):
"""The expected table key dtype."""
return self._key_dtype
@property
def value_dtype(self):
"""The expected table value dtype."""
return self._value_dtype
def initialize(self, table):
"""Returns the table initialization op."""
raise NotImplementedError
class KeyValueTensorInitializer(TableInitializerBase):
"""Table initializers given `keys` and `values` tensors."""
def __init__(self, keys, values, key_dtype=None, value_dtype=None, name=None):
"""Constructs a table initializer object based on keys and values tensors.
Args:
keys: The tensor for the keys.
values: The tensor for the values.
key_dtype: The `keys` data type. Used when `keys` is a python array.
value_dtype: The `values` data type. Used when `values` is a python array.
name: A name for the operation (optional).
"""
with ops.op_scope([keys, values], name, "key_value_init") as scope:
self._keys = ops.convert_to_tensor(keys, dtype=key_dtype, name="keys")
self._values = ops.convert_to_tensor(values,
dtype=value_dtype,
name="values")
self._name = scope
super(KeyValueTensorInitializer, self).__init__(self._keys.dtype,
self._values.dtype)
def initialize(self, table):
"""Initializes the given `table` with `keys` and `values` tensors.
Args:
table: The table to initialize.
Returns:
The operation that initializes the table.
Raises:
TypeError: when the keys and values data types do not match the table
key and value data types.
"""
# pylint: disable=protected-access
table._check_table_dtypes(self._keys.dtype, self._values.dtype)
with ops.op_scope([table], self._name) as scope:
init_op = gen_data_flow_ops._initialize_table(table.table_ref,
self._keys,
self._values,
name=scope)
# pylint: enable=protected-access
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, init_op)
return init_op
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tf.contrib.lookup.lookup_ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
class HashTableOpTest(tf.test.TestCase):
def testHashTable(self):
with self.test_session():
default_val = -1
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
self.assertAllEqual(3, table.size().eval())
input_string = tf.constant(["brain", "salad", "tank"])
output = table.lookup(input_string)
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
def testHashTableFindHighRank(self):
with self.test_session():
default_val = -1
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
self.assertAllEqual(3, table.size().eval())
input_string = tf.constant([["brain", "salad"], ["tank", "tarkus"]])
output = table.lookup(input_string)
result = output.eval()
self.assertAllEqual([[0, 1], [-1, -1]], result)
def testHashTableInitWithPythonArrays(self):
with self.test_session():
default_val = -1
keys = ["brain", "salad", "surgery"]
values = [0, 1, 2]
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys,
values,
value_dtype=tf.int64),
default_val)
table.init.run()
self.assertAllEqual(3, table.size().eval())
input_string = tf.constant(["brain", "salad", "tank"])
output = table.lookup(input_string)
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
def testHashTableInitWithNumPyArrays(self):
with self.test_session():
default_val = -1
keys = np.array(["brain", "salad", "surgery"], dtype=np.str)
values = np.array([0, 1, 2], dtype=np.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
self.assertAllEqual(3, table.size().eval())
input_string = tf.constant(["brain", "salad", "tank"])
output = table.lookup(input_string)
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
def testMultipleHashTables(self):
with self.test_session() as sess:
default_val = -1
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table1 = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table2 = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table3 = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
tf.initialize_all_tables().run()
self.assertAllEqual(3, table1.size().eval())
self.assertAllEqual(3, table2.size().eval())
self.assertAllEqual(3, table3.size().eval())
input_string = tf.constant(["brain", "salad", "tank"])
output1 = table1.lookup(input_string)
output2 = table2.lookup(input_string)
output3 = table3.lookup(input_string)
out1, out2, out3 = sess.run([output1, output2, output3])
self.assertAllEqual([0, 1, -1], out1)
self.assertAllEqual([0, 1, -1], out2)
self.assertAllEqual([0, 1, -1], out3)
def testHashTableWithTensorDefault(self):
with self.test_session():
default_val = tf.constant(-1, tf.int64)
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
input_string = tf.constant(["brain", "salad", "tank"])
output = table.lookup(input_string)
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
def testHashTableWithSparseTensorInput(self):
with self.test_session() as sess:
default_val = tf.constant(-1, tf.int64)
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
sp_indices = [[0, 0], [0, 1], [1, 0]]
sp_shape = [2, 2]
input_tensor = tf.SparseTensor(
tf.constant(sp_indices, tf.int64),
tf.constant(["brain", "salad", "tank"]),
tf.constant(sp_shape, tf.int64))
output = table.lookup(input_tensor)
out_indices, out_values, out_shape = sess.run(output)
self.assertAllEqual([0, 1, -1], out_values)
self.assertAllEqual(sp_indices, out_indices)
self.assertAllEqual(sp_shape, out_shape)
def testSignatureMismatch(self):
with self.test_session():
default_val = -1
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
input_string = tf.constant([1, 2, 3], tf.int64)
with self.assertRaises(TypeError):
table.lookup(input_string)
with self.assertRaises(TypeError):
tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values), "UNK")
def testDTypes(self):
with self.test_session():
default_val = -1
with self.assertRaises(TypeError):
tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(
["a"], [1], [tf.string], tf.int64), default_val)
def testNotInitialized(self):
with self.test_session():
default_val = -1
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(
["a"],
[1],
value_dtype=tf.int64),
default_val)
input_string = tf.constant(["brain", "salad", "surgery"])
output = table.lookup(input_string)
with self.assertRaisesOpError("Table not initialized"):
output.eval()
def testInitializeTwice(self):
with self.test_session():
default_val = -1
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2], tf.int64)
table = tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys, values),
default_val)
table.init.run()
with self.assertRaisesOpError("Table already initialized"):
table.init.run()
def testInitializationWithInvalidDimensions(self):
with self.test_session():
default_val = -1
keys = tf.constant(["brain", "salad", "surgery"])
values = tf.constant([0, 1, 2, 3, 4], tf.int64)
with self.assertRaises(ValueError):
tf.contrib.lookup.HashTable(
tf.contrib.lookup.KeyValueTensorInitializer(keys,
values), default_val)
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册