提交 d251e8c9 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Change pfor logic to work in nested contexts inside an xla compile call.

PiperOrigin-RevId: 257646455
上级 4fecc9ea
......@@ -113,6 +113,21 @@ def _flatten_first_two_dims(x):
PFOR_CONFIG_ARG = "pfor_config"
def _is_under_xla_context():
"""Check if we are currently inside an XLA compile context."""
g = ops.get_default_graph()
while g is not None:
control_flow_context = g._get_control_flow_context() # pylint: disable=protected-access
while control_flow_context is not None:
if control_flow_context.IsXLAContext():
return True
else:
control_flow_context = control_flow_context.outer_context
# If g is a FuncGraph, get its outer_graph.
g = getattr(g, "outer_graph", None)
return False
def pfor(loop_fn, iters, parallel_iterations=None):
"""Equivalent to running `loop_fn` `iters` times and stacking the outputs.
......@@ -162,13 +177,10 @@ def pfor(loop_fn, iters, parallel_iterations=None):
"""
def f():
return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
control_flow_context = ops.get_default_graph()._get_control_flow_context() # pylint: disable=protected-access
# Note that we wrap into a tf.function if in eager execution mode or under
# XLA compilation. The latter is so that we don't compile operations like
# tf.placeholder that are created by the loop body.
if (context.executing_eagerly() or
(control_flow_context is not None and
control_flow_context.IsXLAContext())):
if context.executing_eagerly() or _is_under_xla_context():
f = function.defun(f)
return f()
......
......@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.compiler.xla import xla
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.parallel_for import control_flow_ops as pfor_control_flow_ops
from tensorflow.python.ops.parallel_for.test_util import PForTestCase
......@@ -39,10 +40,31 @@ class PForTest(PForTestCase):
def vectorized_compute(x):
return pfor_control_flow_ops.vectorized_map(compute, x)
result = xla.compile(vectorized_compute,
inputs=[array_ops.ones((10, 5, 3))])
result = xla.compile(
vectorized_compute, inputs=[array_ops.ones((10, 5, 3))])
self.run_and_assert_equal(result, array_ops.ones((10, 1, 3)))
def test_xla_while_loop(self):
if __name__ == "__main__":
def compute(x):
return math_ops.reduce_mean(x, axis=0, keepdims=True)
def vectorized_compute(x, i):
inp = array_ops.gather(x, i)
output = pfor_control_flow_ops.vectorized_map(compute, inp)
output.set_shape([5, 1])
return output
def while_compute(x):
return control_flow_ops.while_loop_v2(
lambda i, _: i < 10,
lambda i, y: (i + 1, y + vectorized_compute(x, i)),
(0, array_ops.zeros([5, 1])))[1]
result = xla.compile(while_compute, inputs=[array_ops.ones((10, 5, 3))])
expected = array_ops.ones([5, 1]) * 10
self.run_and_assert_equal(expected, result)
if __name__ == '__main__':
test.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册