提交 53be8312 编写于 作者: W Wei Ho 提交者: TensorFlower Gardener

Make sure shared_embedding_columns sorts input before using

Change: 136744992
上级 7d8ce2ee
......@@ -74,6 +74,7 @@ from __future__ import print_function
import abc
import collections
import math
import six
from tensorflow.contrib.framework.python.framework import deprecation
from tensorflow.contrib.layers.python.layers import layers
......@@ -957,13 +958,18 @@ def shared_embedding_columns(sparse_id_columns,
Raises:
ValueError: if sparse_id_columns is empty, or its elements are not
compatible with each other.
TypeError: if at least one element of sparse_id_columns is not a
`SparseTensor`.
TypeError: if `sparse_id_columns` is not a sequence or is a string. If at
least one element of `sparse_id_columns` is not a `SparseTensor`.
"""
if combiner is None:
logging.warn("The default value of combiner will change from \"mean\" "
"to \"sqrtn\" after 2016/11/01.")
combiner = "mean"
if (not isinstance(sparse_id_columns, collections.Sequence) or
isinstance(sparse_id_columns, six.string_types)):
raise TypeError(
"sparse_id_columns must be a non-string sequence (ex: list or tuple) "
"instead of type {}.".format(type(sparse_id_columns)))
if len(sparse_id_columns) < 1:
raise ValueError("The input sparse_id_columns should have at least one "
"element.")
......@@ -972,8 +978,6 @@ def shared_embedding_columns(sparse_id_columns,
raise TypeError("Elements of sparse_id_columns must be _SparseColumn, but"
"{} is not.".format(sparse_id_column))
if not isinstance(sparse_id_columns, list):
sparse_id_columns = list(sparse_id_columns)
if len(sparse_id_columns) == 1:
return [
_EmbeddingColumn(sparse_id_columns[0], dimension, combiner, initializer,
......@@ -988,14 +992,17 @@ def shared_embedding_columns(sparse_id_columns,
raise ValueError("The input sparse id columns are not compatible.")
# Construct the shared name and size for shared embedding space.
if not shared_embedding_name:
if len(sparse_id_columns) <= 3:
# Sort the columns so that shared_embedding_name will be deterministic
# even if users pass in unsorted columns from a dict or something.
sorted_columns = sorted(sparse_id_columns)
if len(sorted_columns) <= 3:
shared_embedding_name = "_".join([column.name
for column in sparse_id_columns])
for column in sorted_columns])
else:
shared_embedding_name = "_".join([column.name
for column in sparse_id_columns[0:3]])
for column in sorted_columns[0:3]])
shared_embedding_name += (
"_plus_{}_others".format(len(sparse_id_columns)-3))
"_plus_{}_others".format(len(sorted_columns)-3))
shared_embedding_name += "_shared_embedding"
shared_vocab_size = sparse_id_columns[0].length
......
......@@ -137,6 +137,35 @@ class FeatureColumnTest(tf.test.TestCase):
for i in range(len(d1_value)):
self.assertAllClose(d1_value[i], e1_value[i])
def testSharedEmbeddingColumnDeterminism(self):
  """Checks the auto-generated shared_embedding_name is order-independent.

  Ten sparse columns are created with their names deliberately out of order
  and passed in as a tuple. Every returned column is expected to carry a
  shared embedding name built from the three smallest names plus a count of
  the remaining seven.
  """
  shuffled_names = ["07", "02", "00", "03", "05", "01", "09", "06", "04", "08"]
  sparse_id_columns = tuple(
      tf.contrib.layers.sparse_column_with_keys(name, ["foo", "bar"])
      for name in shuffled_names)
  shared_columns = tf.contrib.layers.shared_embedding_columns(
      sparse_id_columns, dimension=2, combiner="mean")
  self.assertEqual(len(shared_columns), 10)
  expected_name = "00_01_02_plus_7_others_shared_embedding"
  for column in shared_columns:
    self.assertEqual(column.shared_embedding_name, expected_name)
def testSharedEmbeddingColumnErrors(self):
  """Checks that shared_embedding_columns raises TypeError on bad containers.

  Two inputs are tried: a plain string, and a set of sparse columns. Each
  call is expected to raise TypeError.
  """
  # Inputs are built outside the assertRaises blocks so that only the call
  # under test can satisfy the expected TypeError; an error raised while
  # constructing the fixtures would otherwise make the test pass spuriously.

  # Tries passing in a string.
  invalid_string = "Invalid string."
  with self.assertRaises(TypeError):
    tf.contrib.layers.shared_embedding_columns(
        invalid_string, dimension=2, combiner="mean")

  # Tries passing in a set of sparse columns.
  invalid_set = {
      tf.contrib.layers.sparse_column_with_keys("a", ["foo", "bar"]),
      tf.contrib.layers.sparse_column_with_keys("b", ["foo", "bar"]),
  }
  with self.assertRaises(TypeError):
    tf.contrib.layers.shared_embedding_columns(
        invalid_set, dimension=2, combiner="mean")
def testOneHotColumn(self):
a = tf.contrib.layers.sparse_column_with_keys("a", ["a", "b", "c", "d"])
onehot_a = tf.contrib.layers.one_hot_column(a)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册