Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
d251e8c9
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
d251e8c9
编写于
7月 11, 2019
作者:
A
A. Unique TensorFlower
提交者:
TensorFlower Gardener
7月 11, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Change pfor logic to work in nested contexts inside an xla compile call.
PiperOrigin-RevId: 257646455
上级
4fecc9ea
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
41 addition
and
7 deletion
+41
-7
tensorflow/python/ops/parallel_for/control_flow_ops.py
tensorflow/python/ops/parallel_for/control_flow_ops.py
+16
-4
tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py
...flow/python/ops/parallel_for/xla_control_flow_ops_test.py
+25
-3
未找到文件。
tensorflow/python/ops/parallel_for/control_flow_ops.py
浏览文件 @
d251e8c9
...
...
@@ -113,6 +113,21 @@ def _flatten_first_two_dims(x):
PFOR_CONFIG_ARG
=
"pfor_config"
def
_is_under_xla_context
():
"""Check if we are currently inside an XLA compile context."""
g
=
ops
.
get_default_graph
()
while
g
is
not
None
:
control_flow_context
=
g
.
_get_control_flow_context
()
# pylint: disable=protected-access
while
control_flow_context
is
not
None
:
if
control_flow_context
.
IsXLAContext
():
return
True
else
:
control_flow_context
=
control_flow_context
.
outer_context
# If g is a FuncGraph, get its outer_graph.
g
=
getattr
(
g
,
"outer_graph"
,
None
)
return
False
def
pfor
(
loop_fn
,
iters
,
parallel_iterations
=
None
):
"""Equivalent to running `loop_fn` `iters` times and stacking the outputs.
...
...
@@ -162,13 +177,10 @@ def pfor(loop_fn, iters, parallel_iterations=None):
"""
def
f
():
return
_pfor_impl
(
loop_fn
,
iters
,
parallel_iterations
=
parallel_iterations
)
control_flow_context
=
ops
.
get_default_graph
().
_get_control_flow_context
()
# pylint: disable=protected-access
# Note that we wrap into a tf.function if in eager execution mode or under
# XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body.
if
(
context
.
executing_eagerly
()
or
(
control_flow_context
is
not
None
and
control_flow_context
.
IsXLAContext
())):
if
context
.
executing_eagerly
()
or
_is_under_xla_context
():
f
=
function
.
defun
(
f
)
return
f
()
...
...
tensorflow/python/ops/parallel_for/xla_control_flow_ops_test.py
浏览文件 @
d251e8c9
...
...
@@ -22,6 +22,7 @@ from __future__ import print_function
from
tensorflow.python.compiler.xla
import
xla
from
tensorflow.python.framework
import
test_util
from
tensorflow.python.ops
import
array_ops
from
tensorflow.python.ops
import
control_flow_ops
from
tensorflow.python.ops
import
math_ops
from
tensorflow.python.ops.parallel_for
import
control_flow_ops
as
pfor_control_flow_ops
from
tensorflow.python.ops.parallel_for.test_util
import
PForTestCase
...
...
@@ -39,10 +40,31 @@ class PForTest(PForTestCase):
def
vectorized_compute
(
x
):
return
pfor_control_flow_ops
.
vectorized_map
(
compute
,
x
)
result
=
xla
.
compile
(
vectorized_compute
,
inputs
=
[
array_ops
.
ones
((
10
,
5
,
3
))])
result
=
xla
.
compile
(
vectorized_compute
,
inputs
=
[
array_ops
.
ones
((
10
,
5
,
3
))])
self
.
run_and_assert_equal
(
result
,
array_ops
.
ones
((
10
,
1
,
3
)))
def
test_xla_while_loop
(
self
):
if
__name__
==
"__main__"
:
def
compute
(
x
):
return
math_ops
.
reduce_mean
(
x
,
axis
=
0
,
keepdims
=
True
)
def
vectorized_compute
(
x
,
i
):
inp
=
array_ops
.
gather
(
x
,
i
)
output
=
pfor_control_flow_ops
.
vectorized_map
(
compute
,
inp
)
output
.
set_shape
([
5
,
1
])
return
output
def
while_compute
(
x
):
return
control_flow_ops
.
while_loop_v2
(
lambda
i
,
_
:
i
<
10
,
lambda
i
,
y
:
(
i
+
1
,
y
+
vectorized_compute
(
x
,
i
)),
(
0
,
array_ops
.
zeros
([
5
,
1
])))[
1
]
result
=
xla
.
compile
(
while_compute
,
inputs
=
[
array_ops
.
ones
((
10
,
5
,
3
))])
expected
=
array_ops
.
ones
([
5
,
1
])
*
10
self
.
run_and_assert_equal
(
expected
,
result
)
if
__name__
==
'__main__'
:
test
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录