Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
53be8312
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
53be8312
编写于
10月 20, 2016
作者:
W
Wei Ho
提交者:
TensorFlower Gardener
10月 20, 2016
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Make sure shared_embedding_columns sorts input before using
Change: 136744992
上级
7d8ce2ee
变更
2
隐藏空白更改
内联
并排
Showing
2 changed files
with
44 additions
and
8 deletions
+44
-8
tensorflow/contrib/layers/python/layers/feature_column.py
tensorflow/contrib/layers/python/layers/feature_column.py
+15
-8
tensorflow/contrib/layers/python/layers/feature_column_test.py
...rflow/contrib/layers/python/layers/feature_column_test.py
+29
-0
未找到文件。
tensorflow/contrib/layers/python/layers/feature_column.py
浏览文件 @
53be8312
...
...
@@ -74,6 +74,7 @@ from __future__ import print_function
import
abc
import
collections
import
math
import
six
from
tensorflow.contrib.framework.python.framework
import
deprecation
from
tensorflow.contrib.layers.python.layers
import
layers
...
...
@@ -957,13 +958,18 @@ def shared_embedding_columns(sparse_id_columns,
Raises:
ValueError: if sparse_id_columns is empty, or its elements are not
compatible with each other.
TypeError: if
at least one element of sparse_id_columns is not a
`SparseTensor`.
TypeError: if
`sparse_id_columns` is not a sequence or is a string. If at
least one element of `sparse_id_columns` is not a
`SparseTensor`.
"""
if
combiner
is
None
:
logging
.
warn
(
"The default value of combiner will change from
\"
mean
\"
"
"to
\"
sqrtn
\"
after 2016/11/01."
)
combiner
=
"mean"
if
(
not
isinstance
(
sparse_id_columns
,
collections
.
Sequence
)
or
isinstance
(
sparse_id_columns
,
six
.
string_types
)):
raise
TypeError
(
"sparse_id_columns must be a non-string sequence (ex: list or tuple) "
"instead of type {}."
.
format
(
type
(
sparse_id_columns
)))
if
len
(
sparse_id_columns
)
<
1
:
raise
ValueError
(
"The input sparse_id_columns should have at least one "
"element."
)
...
...
@@ -972,8 +978,6 @@ def shared_embedding_columns(sparse_id_columns,
raise
TypeError
(
"Elements of sparse_id_columns must be _SparseColumn, but"
"{} is not."
.
format
(
sparse_id_column
))
if
not
isinstance
(
sparse_id_columns
,
list
):
sparse_id_columns
=
list
(
sparse_id_columns
)
if
len
(
sparse_id_columns
)
==
1
:
return
[
_EmbeddingColumn
(
sparse_id_columns
[
0
],
dimension
,
combiner
,
initializer
,
...
...
@@ -988,14 +992,17 @@ def shared_embedding_columns(sparse_id_columns,
raise
ValueError
(
"The input sparse id columns are not compatible."
)
# Construct the shared name and size for shared embedding space.
if
not
shared_embedding_name
:
if
len
(
sparse_id_columns
)
<=
3
:
# Sort the columns so that shared_embedding_name will be deterministic
# even if users pass in unsorted columns from a dict or something.
sorted_columns
=
sorted
(
sparse_id_columns
)
if
len
(
sorted_columns
)
<=
3
:
shared_embedding_name
=
"_"
.
join
([
column
.
name
for
column
in
s
parse_i
d_columns
])
for
column
in
s
orte
d_columns
])
else
:
shared_embedding_name
=
"_"
.
join
([
column
.
name
for
column
in
s
parse_i
d_columns
[
0
:
3
]])
for
column
in
s
orte
d_columns
[
0
:
3
]])
shared_embedding_name
+=
(
"_plus_{}_others"
.
format
(
len
(
s
parse_i
d_columns
)
-
3
))
"_plus_{}_others"
.
format
(
len
(
s
orte
d_columns
)
-
3
))
shared_embedding_name
+=
"_shared_embedding"
shared_vocab_size
=
sparse_id_columns
[
0
].
length
...
...
tensorflow/contrib/layers/python/layers/feature_column_test.py
浏览文件 @
53be8312
...
...
@@ -137,6 +137,35 @@ class FeatureColumnTest(tf.test.TestCase):
for
i
in
range
(
len
(
d1_value
)):
self
.
assertAllClose
(
d1_value
[
i
],
e1_value
[
i
])
def testSharedEmbeddingColumnDeterminism(self):
  # The auto-generated shared_embedding_name must not depend on the order in
  # which the caller supplies the columns, so the keys are deliberately
  # shuffled and handed over as a tuple rather than a list.
  shuffled_keys = ["07", "02", "00", "03", "05", "01", "09", "06", "04", "08"]
  sparse_id_columns = tuple(
      tf.contrib.layers.sparse_column_with_keys(key, ["foo", "bar"])
      for key in shuffled_keys)

  shared_columns = tf.contrib.layers.shared_embedding_columns(
      sparse_id_columns, dimension=2, combiner="mean")

  self.assertEqual(len(shared_columns), 10)
  # Every returned column reports the same name: the three smallest keys
  # followed by a count of the remaining seven.
  expected_name = "00_01_02_plus_7_others_shared_embedding"
  for shared_column in shared_columns:
    self.assertEqual(shared_column.shared_embedding_name, expected_name)
def testSharedEmbeddingColumnErrors(self):
  # Checks that shared_embedding_columns rejects inputs that are not
  # non-string sequences.
  #
  # Fixtures are built outside the assertRaises blocks so that only the call
  # under test can satisfy the expected TypeError. A TypeError raised while
  # constructing a fixture now surfaces as a test error instead of being
  # mistaken for the expected rejection.

  # A string must be rejected.
  invalid_string = "Invalid string."
  with self.assertRaises(TypeError):
    tf.contrib.layers.shared_embedding_columns(
        invalid_string, dimension=2, combiner="mean")

  # A set of sparse columns must be rejected.
  invalid_set = {
      tf.contrib.layers.sparse_column_with_keys("a", ["foo", "bar"]),
      tf.contrib.layers.sparse_column_with_keys("b", ["foo", "bar"]),
  }
  with self.assertRaises(TypeError):
    tf.contrib.layers.shared_embedding_columns(
        invalid_set, dimension=2, combiner="mean")
def
testOneHotColumn
(
self
):
a
=
tf
.
contrib
.
layers
.
sparse_column_with_keys
(
"a"
,
[
"a"
,
"b"
,
"c"
,
"d"
])
onehot_a
=
tf
.
contrib
.
layers
.
one_hot_column
(
a
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录