Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
822d64f0
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
822d64f0
编写于
5月 29, 2017
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
5月 29, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix embedding_lookup() bug where normalization did not work with ids of rank != 1.
PiperOrigin-RevId: 157422220
上级
8cad6b82
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
43 addition
and
7 deletion
+43
-7
tensorflow/python/kernel_tests/embedding_ops_test.py
tensorflow/python/kernel_tests/embedding_ops_test.py
+25
-0
tensorflow/python/ops/embedding_ops.py
tensorflow/python/ops/embedding_ops.py
+18
-7
未找到文件。
tensorflow/python/kernel_tests/embedding_ops_test.py
浏览文件 @
822d64f0
...
...
@@ -547,6 +547,31 @@ class EmbeddingLookupTest(test.TestCase):
sharded
=
embedding_ops
.
embedding_lookup
(
split_params
,
ids
).
eval
()
self
.
assertAllEqual
(
simple
,
sharded
)
def
testHigherRankMaxNorm
(
self
):
np
.
random
.
seed
(
8
)
with
self
.
test_session
():
for
params_shape
in
(
12
,),
(
6
,
3
):
params
=
2
*
np
.
ones
(
params_shape
)
params_norm
=
params
/
np
.
sqrt
(
np
.
sum
(
params
*
params
,
tuple
(
range
(
params
.
ndim
)[
1
:]),
keepdims
=
True
))
for
ids_shape
in
(),
(
3
),
(
4
,
3
),
(
2
,
3
,
4
):
ids
=
np
.
random
.
randint
(
params
.
shape
[
0
],
size
=
np
.
prod
(
ids_shape
,
dtype
=
np
.
int64
)).
reshape
(
ids_shape
)
# Compare nonsharded to gather
simple
=
embedding_ops
.
embedding_lookup
(
params
,
ids
,
max_norm
=
1.0
).
eval
()
self
.
assertAllEqual
(
simple
,
array_ops
.
gather
(
params_norm
,
ids
).
eval
())
# Run a few random sharded versions
for
procs
in
1
,
2
,
3
:
stride
=
procs
*
math_ops
.
range
(
params
.
shape
[
0
]
//
procs
)
split_params
=
[
array_ops
.
gather
(
params
,
stride
+
p
)
for
p
in
xrange
(
procs
)
]
sharded
=
embedding_ops
.
embedding_lookup
(
split_params
,
ids
,
max_norm
=
1.0
).
eval
()
self
.
assertAllEqual
(
simple
,
sharded
)
class
EmbeddingLookupSparseTest
(
test
.
TestCase
):
...
...
tensorflow/python/ops/embedding_ops.py
浏览文件 @
822d64f0
...
...
@@ -103,14 +103,25 @@ def embedding_lookup(params, ids, partition_strategy="mod", name=None,
params
=
list
(
params
)
# Iterate to get the underlying Variables.
if
not
isinstance
(
params
,
list
):
params
=
[
params
]
def
maybe_normalize
(
x
):
if
max_norm
is
not
None
:
if
x
.
get_shape
().
ndims
is
not
None
:
ndims
=
x
.
get_shape
().
ndims
else
:
ndims
=
array_ops
.
size
(
array_ops
.
shape
(
x
))
return
clip_ops
.
clip_by_norm
(
x
,
max_norm
,
axes
=
list
(
range
(
1
,
ndims
)))
return
x
"""Normalizes the embeddings in x if max_norm is not None."""
if
max_norm
is
None
:
return
x
static
=
True
ids_rank
=
ops
.
convert_to_tensor
(
ids
).
get_shape
().
ndims
if
ids_rank
is
None
:
ids_rank
=
array_ops
.
rank
(
ids
)
static
=
False
x_rank
=
x
.
get_shape
().
ndims
if
x_rank
is
None
:
x_rank
=
array_ops
.
rank
(
x
)
static
=
False
return
clip_ops
.
clip_by_norm
(
x
,
max_norm
,
axes
=
list
(
range
(
ids_rank
,
x_rank
))
if
static
else
math_ops
.
range
(
ids_rank
,
x_rank
))
with
ops
.
name_scope
(
name
,
"embedding_lookup"
,
params
+
[
ids
])
as
name
:
np
=
len
(
params
)
# Number of partitions
# Preserve the resource variable status to avoid accidental dense reads.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录