Commit 866de15d authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Change public functions to support non-tensor inputs.

Fix `index_table_from_tensor` to handle `int32` keys with no oov buckets.
Add better tests for sparse integer inputs.
Change: 150006782
Parent 9ee1e726
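A minimal sketch of the `index_table_from_tensor` fix in action, mirroring the new test added at the bottom of this diff; it assumes the graph-mode `tf.contrib.lookup` API of this era:

```python
import tensorflow as tf
from tensorflow.contrib import lookup

with tf.Session() as sess:
  # int32 keys with the default num_oov_buckets=0 used to fail; after this
  # change the lookup succeeds.
  table = lookup.index_table_from_tensor(
      mapping=(42, 1, -1000), dtype=tf.int32)
  ids = table.lookup(tf.constant((1, -1000, 11), dtype=tf.int32))
  sess.run(tf.tables_initializer())
  print(sess.run(ids))  # [1 2 -1]; -1 is the default id for the missing key 11
```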
@@ -136,6 +136,19 @@ def _embeddings_from_arguments(column,
max_norm=args.max_norm)
def _convert_to_tensors(values):
"""Modifies `values`, if necessary, to convert all to `Tensor` objects.
Args:
values: Dict mapping a key to tensor-like values.
"""
for k in sorted(
# pylint: disable=protected-access
values.keys(),
key=lambda k: k.key if isinstance(k, fc._FeatureColumn) else k):
values[k] = sparse_tensor_py.convert_to_tensor_or_sparse_tensor(values[k])
def _input_from_feature_columns(columns_to_tensors,
feature_columns,
weight_collections,
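Stepping back to the `_convert_to_tensors` helper added above, a quick sketch of what it does; the feature names and values here are hypothetical:

```python
# Hypothetical inputs: plain Python lists rather than Tensors.
features = {"age": [[23.0], [31.0]], "income": [[5.0], [6.5]]}
_convert_to_tensors(features)  # mutates the dict in place
# Each value is now a Tensor (or a SparseTensor for sparse-tensor-like
# values), so the downstream layer-building code can assume tensor inputs.
```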
@@ -148,6 +161,7 @@ def _input_from_feature_columns(columns_to_tensors,
with variable_scope.variable_scope(scope,
default_name=default_name,
values=columns_to_tensors.values()):
_convert_to_tensors(columns_to_tensors)
output_tensors = []
transformer = _Transformer(columns_to_tensors)
if weight_collections:
@@ -226,7 +240,8 @@ def input_from_feature_columns(columns_to_tensors,
columns_to_tensors: A mapping from feature column to tensors. 'string' key
means a base feature (not-transformed). It can have FeatureColumn as a
key too. That means that FeatureColumn is already transformed by input
pipeline. For example, `inflow` may have handled transformations.
pipeline. For example, `inflow` may have handled transformations. This
dictionary will be modified by this function.
feature_columns: A set containing all the feature columns. All items in the
set should be instances of classes derived from FeatureColumn.
weight_collections: List of graph collections to which weights are added.
@@ -235,7 +250,7 @@ def input_from_feature_columns(columns_to_tensors,
scope: Optional scope for variable_scope.
Returns:
A Tensor which can be consumed by hidden layers in the neural network.
A `Tensor` which can be consumed by hidden layers in the neural network.
Raises:
ValueError: if FeatureColumn cannot be consumed by a neural network.
@@ -268,7 +283,8 @@ def sequence_input_from_feature_columns(columns_to_tensors,
columns_to_tensors: A mapping from feature column to tensors. 'string' key
means a base feature (not-transformed). It can have FeatureColumn as a
key too. That means that FeatureColumn is already transformed by input
pipeline. For example, `inflow` may have handled transformations.
pipeline. For example, `inflow` may have handled transformations. This
dictionary will be modified by this function.
feature_columns: A set containing all the feature columns. All items in the
set should be instances of classes derived from FeatureColumn.
weight_collections: List of graph collections to which weights are added.
@@ -410,7 +426,8 @@ def joint_weighted_sum_from_feature_columns(columns_to_tensors,
columns_to_tensors: A mapping from feature column to tensors. 'string' key
means a base feature (not-transformed). It can have FeatureColumn as a
key too. That means that FeatureColumn is already transformed by input
pipeline. For example, `inflow` may have handled transformations.
pipeline. For example, `inflow` may have handled transformations. This
dictionary will be modified by this function.
feature_columns: A set containing all the feature columns. All items in the
set should be instances of classes derived from FeatureColumn.
num_outputs: An integer specifying number of outputs. Default value is 1.
@@ -435,6 +452,7 @@ def joint_weighted_sum_from_feature_columns(columns_to_tensors,
scope,
default_name='joint_weighted_sum_from_feature_columns',
values=columns_to_tensors.values()):
_convert_to_tensors(columns_to_tensors)
transformer = _Transformer(columns_to_tensors)
embedding_lookup_arguments = []
for column in sorted(set(feature_columns), key=lambda x: x.key):
@@ -499,7 +517,8 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
columns_to_tensors: A mapping from feature column to tensors. 'string' key
means a base feature (not-transformed). It can have FeatureColumn as a
key too. That means that FeatureColumn is already transformed by input
pipeline. For example, `inflow` may have handled transformations.
pipeline. For example, `inflow` may have handled transformations. This
dictionary will be modified by this function.
feature_columns: A set containing all the feature columns. All items in the
set should be instances of classes derived from FeatureColumn.
num_outputs: An integer specifying number of outputs. Default value is 1.
@@ -523,6 +542,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
scope,
default_name='weighted_sum_from_feature_columns',
values=columns_to_tensors.values()):
_convert_to_tensors(columns_to_tensors)
output_tensors = []
column_to_variable = dict()
transformer = _Transformer(columns_to_tensors)
@@ -677,7 +697,8 @@ def transform_features(features, feature_columns):
```
Args:
features: A dictionary of features.
features: A dictionary mapping string names to `Tensor` or `SparseTensor`
feature values. This dictionary will not be modified by this function.
feature_columns: An iterable containing all the feature columns. All items
should be instances of classes derived from _FeatureColumn.
@@ -685,15 +706,15 @@ def transform_features(features, feature_columns):
A `dict` mapping FeatureColumn to `Tensor` and `SparseTensor` values.
"""
check_feature_columns(feature_columns)
columns_to_tensor = features.copy()
transformer = _Transformer(columns_to_tensor)
columns_to_tensors = features.copy()
_convert_to_tensors(columns_to_tensors)
transformer = _Transformer(columns_to_tensors)
for column in sorted(set(feature_columns), key=lambda x: x.key):
transformer.transform(column)
keys = list(columns_to_tensor.keys())
for k in keys:
for k in list(columns_to_tensors.keys()):
if k not in feature_columns:
columns_to_tensor.pop(k)
return columns_to_tensor
columns_to_tensors.pop(k)
return columns_to_tensors
def parse_feature_columns_from_sequence_examples(
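A hedged usage sketch of the updated `transform_features`, with hypothetical column and feature names:

```python
import tensorflow as tf
from tensorflow.contrib import layers

# Hypothetical sparse feature and its vocabulary column.
country = layers.sparse_column_with_keys("country", keys=["US", "CA"])
features = {"country": tf.SparseTensor(
    indices=[[0, 0]], values=["US"], dense_shape=[1, 1])}
transformed = layers.transform_features(features, [country])
# `features` itself is not modified (the function copies it first);
# `transformed` maps the `country` column to its id SparseTensor.
```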
@@ -765,6 +786,7 @@ def _log_variable(variable):
def _infer_real_valued_column_for_tensor(name, tensor):
"""Creates a real_valued_column for given tensor and name."""
tensor = sparse_tensor_py.convert_to_tensor_or_sparse_tensor(tensor)
if isinstance(tensor, sparse_tensor_py.SparseTensor):
raise ValueError(
'SparseTensor is not supported for auto detection. Please define '
@@ -817,6 +839,13 @@ def check_feature_columns(feature_columns):
seen_keys.add(key)
def _sort_columns(columns_or_strings):
return sorted(
# pylint: disable=protected-access
columns_or_strings,
key=lambda k: k.key if isinstance(k, fc._FeatureColumn) else k)
class _Transformer(object):
"""Handles all the transformations defined by FeatureColumn if needed.
@@ -874,7 +903,7 @@ class _Transformer(object):
Raises:
ValueError: if FeatureColumn cannot be handled by this Transformer.
"""
logging.debug('Transforming feature_column %s', feature_column)
logging.debug('Transforming feature_column %s.', feature_column)
if feature_column in self._columns_to_tensors:
# Feature_column is already transformed.
return self._columns_to_tensors[feature_column]
@@ -1029,13 +1029,12 @@ def index_table_from_tensor(mapping,
name="table_init")
table = HashTable(
init, default_value, shared_name=shared_name, name=hash_table_scope)
if num_oov_buckets:
table = IdTableWithHashBuckets(
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
name=feat_to_id_scope,
key_dtype=dtype)
table = IdTableWithHashBuckets(
table,
num_oov_buckets=num_oov_buckets,
hasher_spec=hasher_spec,
name=feat_to_id_scope,
key_dtype=dtype)
return table
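The hunk above drops the `if num_oov_buckets:` guard, so the underlying `HashTable` is now always wrapped in `IdTableWithHashBuckets`. A plausible reading (an inference, not stated in the diff) is that the wrapper is what reconciles `int32` query keys with the `int64`-keyed hash table, so skipping it when `num_oov_buckets` was 0 is exactly what broke the `int32` no-bucket case. A simplified, non-verbatim sketch of the construction path after the change:

```python
# Simplified sketch with hypothetical locals; not verbatim library code.
table = HashTable(init, default_value)  # vocabulary stored with int64 keys
table = IdTableWithHashBuckets(         # wrapping now happens even with zero
    table, num_oov_buckets=0,           # OOV buckets, so int32 queries are
    key_dtype=tf.int32)                 # handled by the wrapper
```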
@@ -1341,6 +1341,17 @@ class IndexTableFromTensor(test.TestCase):
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, 3), ids.eval())
def test_int32_index_table_from_tensor_with_no_buckets(self):
with self.test_session():
table = lookup.index_table_from_tensor(
mapping=(42, 1, -1000), dtype=dtypes.int32)
ids = table.lookup(
constant_op.constant((1, -1000, 11), dtype=dtypes.int32))
self.assertRaises(errors_impl.OpError, ids.eval)
data_flow_ops.tables_initializer().run()
self.assertAllEqual((1, 2, -1), ids.eval())
def test_int64_index_table_from_tensor_with_tensor_init(self):
with self.test_session():
table = lookup.index_table_from_tensor(