Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
ec0b5d1c
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ec0b5d1c
编写于
7月 16, 2019
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
7月 16, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Pfor: make unsorted_segment_sum converter safe for int64 types.
PiperOrigin-RevId: 258450422
上级
02433bb9
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
25 addition
and
16 deletion
+25
-16
tensorflow/python/ops/parallel_for/math_test.py
tensorflow/python/ops/parallel_for/math_test.py
+18
-13
tensorflow/python/ops/parallel_for/pfor.py
tensorflow/python/ops/parallel_for/pfor.py
+7
-3
未找到文件。
tensorflow/python/ops/parallel_for/math_test.py
浏览文件 @
ec0b5d1c
...
@@ -422,9 +422,13 @@ class MathTest(PForTestCase):
...
@@ -422,9 +422,13 @@ class MathTest(PForTestCase):
def
test_unsorted_segment_sum
(
self
):
def
test_unsorted_segment_sum
(
self
):
t
=
random_ops
.
random_uniform
([
3
,
3
,
2
])
t
=
random_ops
.
random_uniform
([
3
,
3
,
2
])
segment_ids
=
constant_op
.
constant
([[
0
,
0
,
2
],
[
0
,
1
,
2
],
[
2
,
2
,
2
]])
for
segment_ids_dtype
in
(
dtypes
.
int32
,
dtypes
.
int64
):
num_segments
=
3
for
num_segments_dtype
in
(
dtypes
.
int32
,
dtypes
.
int64
):
segment_ids
=
constant_op
.
constant
([[
0
,
0
,
2
],
[
0
,
1
,
2
],
[
2
,
2
,
2
]],
dtype
=
segment_ids_dtype
)
num_segments
=
constant_op
.
constant
(
3
,
dtype
=
num_segments_dtype
)
# pylint: disable=cell-var-from-loop
def
loop_fn
(
i
):
def
loop_fn
(
i
):
data
=
array_ops
.
gather
(
t
,
i
)
data
=
array_ops
.
gather
(
t
,
i
)
data_0
=
array_ops
.
gather
(
t
,
0
)
data_0
=
array_ops
.
gather
(
t
,
0
)
...
@@ -433,6 +437,7 @@ class MathTest(PForTestCase):
...
@@ -433,6 +437,7 @@ class MathTest(PForTestCase):
return
(
math_ops
.
unsorted_segment_sum
(
data
,
seg_ids
,
num_segments
),
return
(
math_ops
.
unsorted_segment_sum
(
data
,
seg_ids
,
num_segments
),
math_ops
.
unsorted_segment_sum
(
data_0
,
seg_ids
,
num_segments
),
math_ops
.
unsorted_segment_sum
(
data_0
,
seg_ids
,
num_segments
),
math_ops
.
unsorted_segment_sum
(
data
,
seg_ids_0
,
num_segments
))
math_ops
.
unsorted_segment_sum
(
data
,
seg_ids_0
,
num_segments
))
# pylint: enable=cell-var-from-loop
self
.
_test_loop_fn
(
loop_fn
,
3
,
[
dtypes
.
float32
]
*
3
)
self
.
_test_loop_fn
(
loop_fn
,
3
,
[
dtypes
.
float32
]
*
3
)
...
...
tensorflow/python/ops/parallel_for/pfor.py
浏览文件 @
ec0b5d1c
...
@@ -2198,10 +2198,14 @@ def _convert_unsortedsegmentsum(pfor_input):
...
@@ -2198,10 +2198,14 @@ def _convert_unsortedsegmentsum(pfor_input):
segment_ids
=
pfor_input
.
stacked_input
(
1
)
segment_ids
=
pfor_input
.
stacked_input
(
1
)
# TODO(agarwal): handle stacked?
# TODO(agarwal): handle stacked?
num_segments
=
pfor_input
.
unstacked_input
(
2
)
num_segments
=
pfor_input
.
unstacked_input
(
2
)
segment_shape
=
array_ops
.
shape
(
segment_ids
)
if
segment_ids
.
dtype
!=
num_segments
.
dtype
:
segment_ids
=
math_ops
.
cast
(
segment_ids
,
dtypes
.
int64
)
num_segments
=
math_ops
.
cast
(
num_segments
,
dtypes
.
int64
)
dtype
=
segment_ids
.
dtype
segment_shape
=
array_ops
.
shape
(
segment_ids
,
out_type
=
dtype
)
n
=
segment_shape
[
0
]
n
=
segment_shape
[
0
]
ones
=
array_ops
.
ones_like
(
segment_shape
)[
1
:]
ones
=
array_ops
.
ones_like
(
segment_shape
,
dtype
=
dtype
)[
1
:]
segment_offset
=
num_segments
*
math_ops
.
range
(
n
)
segment_offset
=
num_segments
*
math_ops
.
range
(
n
,
dtype
=
dtype
)
segment_offset
=
array_ops
.
reshape
(
segment_offset
,
segment_offset
=
array_ops
.
reshape
(
segment_offset
,
array_ops
.
concat
([[
n
],
ones
],
axis
=
0
))
array_ops
.
concat
([[
n
],
ones
],
axis
=
0
))
segment_ids
+=
segment_offset
segment_ids
+=
segment_offset
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录