diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py index 5c115f7ae311ddabef1ff6d7279d724bb1e18f85..a8a65dde131af71cd0189365457908d76f113bf1 100644 --- a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import script_ops from tensorflow.python.platform import test @@ -500,10 +501,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase): def testMapAndBatchControlFlow(self, numa_aware): def map_fn(x): - previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2 - control_flow_ops.ENABLE_COND_V2 = True + previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = True return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x) - control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value + control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value return return_value dataset = dataset_ops.Dataset.range(100).apply( diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py index df3cebd2e0c2f37711dc41cf60409c2660bf3e2c..0e48d3c87587e8315a5b0c3a77ecd487debc1334 100644 --- a/tensorflow/python/framework/test_util.py +++ b/tensorflow/python/framework/test_util.py @@ -67,9 +67,8 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import versions from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import script_ops -from tensorflow.python.ops import tensor_array_ops from tensorflow.python.ops import variables from tensorflow.python.platform import googletest from tensorflow.python.platform import tf_logging as logging @@ -409,42 +408,12 @@ def enable_control_flow_v2(fn): """ def wrapper(*args, **kwargs): - enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2 - enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2 - enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 - control_flow_ops.ENABLE_COND_V2 = True - control_flow_ops.ENABLE_WHILE_V2 = True - tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True + enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = True try: fn(*args, **kwargs) finally: - control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old - control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old - tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old - - return wrapper - - -def enable_tensor_array_v2(fn): - """Decorator for enabling _GraphTensorArrayV2 on a test. - - Note this enables _GraphTensorArrayV2 after running the test class's - setup/teardown methods. - - Args: - fn: the function to be wrapped - - Returns: - The wrapped function - """ - - def wrapper(*args, **kwargs): - enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 - tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True - try: - fn(*args, **kwargs) - finally: - tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old + control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old return wrapper @@ -493,7 +462,7 @@ def with_control_flow_v2(cls): Returns: cls with new test methods added """ - if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: return cls for name, value in cls.__dict__.copy().items(): diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 0fd293ebba3044097453c18fb625fc0dee19b19f..21ded25a116455194760dca11622b04ec6038ed6 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import functional_ops from tensorflow.python.ops import gen_array_ops @@ -700,7 +701,8 @@ class ControlFlowTest(test.TestCase): v1_msg = "The two structures don't have the same nested structure" v2_msg = "Outputs of true_fn and false_fn must have the same structure" with self.assertRaisesRegexp( - ValueError, v2_msg if control_flow_ops.ENABLE_COND_V2 else v1_msg): + ValueError, + v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg): r = control_flow_ops.cond(pred, fn1, fn2) self.evaluate(r) @@ -859,7 +861,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0) # v1 control flow gets None second derivative for some reason. - if not control_flow_ops.ENABLE_COND_V2: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: self.assertIsNone(grad_grad) return @@ -949,7 +951,7 @@ class ControlFlowTest(test.TestCase): # In defuns, all prints should execute in program order. # This doesn't work with legacy control flow. - if control_flow_ops.ENABLE_COND_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: @eager_function.defun def cond(): @@ -1003,7 +1005,7 @@ class ControlFlowTest(test.TestCase): # In defuns, all prints should execute in program order. # This doesn't work with legacy control flow. - if control_flow_ops.ENABLE_WHILE_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: @eager_function.defun def while_loop(): @@ -1161,7 +1163,7 @@ class ControlFlowTest(test.TestCase): gs = gradients_impl.gradients(loop_no_xla, v) self.evaluate(gs) # This should execute without error. - if control_flow_ops.ENABLE_WHILE_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() with self.assertRaisesRegexp( @@ -1219,7 +1221,7 @@ class ControlFlowTest(test.TestCase): lambda i, x: (i + 1, v * x), (0, 1.0), maximum_iterations=max_iter_holder[0]) - if control_flow_ops.ENABLE_WHILE_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: xla_context = control_flow_ops.XLAControlFlowContext() xla_context.Enter() with self.assertRaisesRegexp( @@ -1863,7 +1865,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(sess.run(grad, {pred: True}), 8.0) self.assertEqual(sess.run(grad, {pred: False}), 0.0) - if not control_flow_ops.ENABLE_WHILE_V2: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: return self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0) @@ -2399,7 +2401,7 @@ class ControlFlowTest(test.TestCase): # outer_loop(x) = g(g(x)) = 4x + 81 # outer_loop'(x) = 4 # Note that v1 control flow gets 4.0 as well if the cond is removed. - if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: self.assertEqual(grad, 4.0) def testWhile_NestedInput(self): @@ -2982,7 +2984,7 @@ class ControlFlowTest(test.TestCase): result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32)) grad_theta = gradients_impl.gradients(result, theta) - if not control_flow_ops.ENABLE_WHILE_V2: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: with self.assertRaisesRegexp(TypeError, "Second-order gradient"): gradients_impl.gradients(grad_theta, theta) grad_theta_stopped = array_ops.stop_gradient(grad_theta) @@ -3514,7 +3516,7 @@ class ControlFlowTest(test.TestCase): self.assertEqual(r[1].eval(), 65536.0) self.assertEqual(grad.eval(), 524288.0) # while_v2 does not have stacks. - if not control_flow_ops.ENABLE_WHILE_V2: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: self.assertEqual( len([op for op in x.graph.get_operations() if op.type == "StackV2" ]), 1) diff --git a/tensorflow/python/kernel_tests/control_flow_util_v2_test.py b/tensorflow/python/kernel_tests/control_flow_util_v2_test.py index d0374a77005db4597ddbce76c1d2a3b9ac0e792d..08d3214e288bf873515f0b5a45ddf1e50ee1b281 100644 --- a/tensorflow/python/kernel_tests/control_flow_util_v2_test.py +++ b/tensorflow/python/kernel_tests/control_flow_util_v2_test.py @@ -23,6 +23,7 @@ from tensorflow.python.eager import function from tensorflow.python.framework import constant_op from tensorflow.python.framework import test_util from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import control_flow_util_v2 from tensorflow.python.platform import test @@ -30,14 +31,11 @@ from tensorflow.python.platform import test class ControlFlowUtilV2Test(test.TestCase): def setUp(self): - self._enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2 - self._enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2 - control_flow_ops.ENABLE_COND_V2 = True - control_flow_ops.ENABLE_WHILE_V2 = True + self._enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = True def tearDown(self): - control_flow_ops.ENABLE_COND_V2 = self._enable_cond_v2_old - control_flow_ops.ENABLE_WHILE_V2 = self._enable_while_v2_old + control_flow_util.ENABLE_CONTROL_FLOW_V2 = self._enable_control_flow_v2_old def _create_control_flow(self, expect_in_defun): """Helper method for testInDefun.""" diff --git a/tensorflow/python/kernel_tests/tensor_array_ops_test.py b/tensorflow/python/kernel_tests/tensor_array_ops_test.py index 88625841bcc982bf477b619f3da0b70498f0542f..6d8e3e83566ad30c8ec0ba9330e7b5c829b261a6 100644 --- a/tensorflow/python/kernel_tests/tensor_array_ops_test.py +++ b/tensorflow/python/kernel_tests/tensor_array_ops_test.py @@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import data_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import gradients_impl @@ -345,7 +346,7 @@ class TensorArrayTest(test.TestCase): @test_util.run_deprecated_v1 def testSkipEagerTensorArrayGradGrad(self): - if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: self.skipTest("Legacy TensorArray does not support double derivatives.") with self.test_session(use_gpu=True) as session: x = constant_op.constant(4.0) @@ -429,7 +430,7 @@ class TensorArrayTest(test.TestCase): with self.session(use_gpu=True): ta = _make_ta(3, "foo", dtype=dtypes.float32) # Test writing the wrong datatype - if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and + if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly()): error_msg = ("Invalid data types; op elements string but list elements " "float") @@ -440,7 +441,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError(error_msg): self.evaluate(ta.write(0, "wrong_type_scalar").flow) - if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and + if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly()): error_msg = "Trying to modify element -1 in a list with 3 elements." else: @@ -448,7 +449,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError(error_msg): self.evaluate(ta.write(-1, 3.0).flow) - if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and + if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly()): error_msg = "Trying to modify element 3 in a list with 3 elements" else: @@ -467,14 +468,14 @@ class TensorArrayTest(test.TestCase): # Test reading wrong datatype (only possible when constructing graphs). if (not context.executing_eagerly() and - not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2): + not control_flow_util.ENABLE_CONTROL_FLOW_V2): r0_bad = gen_data_flow_ops.tensor_array_read_v3( handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow) with self.assertRaisesOpError( "TensorArray dtype is float but Op requested dtype double."): self.evaluate(r0_bad) - if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and + if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly()): error_msg = "Trying to access element -1 in a list with 3 elements." else: @@ -483,7 +484,7 @@ class TensorArrayTest(test.TestCase): with self.assertRaisesOpError(error_msg): self.evaluate(ta.read(-1)) - if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and + if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly()): error_msg = "Trying to access element 3 in a list with 3 elements." else: @@ -550,7 +551,7 @@ class TensorArrayTest(test.TestCase): ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1}) error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1" - if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and + if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode else r"Expected sum of lengths to be equal to values.shape\[0\], " r"but sum of lengths is 1 and value's shape is: \[3\]") @@ -558,7 +559,7 @@ class TensorArrayTest(test.TestCase): self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow) ta = _make_ta(1, "baz") - if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and not in_eager_mode: + if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode: with self.assertRaisesRegexp( ValueError, "Shape must be at least rank 1 but is rank 0"): self.evaluate(ta.split(1.0, [1]).flow) @@ -568,7 +569,7 @@ class TensorArrayTest(test.TestCase): ): self.evaluate(ta.split(1.0, [1]).flow) - if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 or in_eager_mode: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode: ta = _make_ta(2, "buz") with self.assertRaisesOpError( r"TensorArray's size is not equal to the size of lengths " @@ -1003,21 +1004,6 @@ class TensorArrayTest(test.TestCase): # self._testWhileLoopWritePackGradients( # dynamic_size=False, dtype=tf.int64) - @test_util.disable_control_flow_v2("Testing v1 while_loop with v2 TA") - @test_util.enable_tensor_array_v2 - def testWhileLoopV1WithTensorArrayV2(self): - size = 3 - ta = tensor_array_ops.TensorArray( - dtype=dtypes.int32, size=size, element_shape=tensor_shape.scalar()) - - def Body(counter, ta): - return counter + 1, ta.write(counter, counter) - - _, ta = control_flow_ops.while_loop(lambda i, _: i < size, Body, [0, ta]) - - for i in range(size): - self.assertEqual(self.evaluate(ta.read(i)), i) - @test_util.disable_control_flow_v2("b/117943489 (dynamic_size)") @test_util.run_v1_only("b/117943489") def testSkipEagerWhileLoopDynamicWritePackGradients(self): @@ -1270,7 +1256,7 @@ class TensorArrayTest(test.TestCase): self.assertEqual((2, 2), w0.read(1).get_shape()) else: self.assertEqual(r0.get_shape().ndims, None) - if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2: + if not control_flow_util.ENABLE_CONTROL_FLOW_V2: self.assertEqual( tensor_shape.TensorShape( ta1.handle.op.get_attr("element_shape")).ndims, None) @@ -1347,8 +1333,8 @@ class TensorArrayTest(test.TestCase): "TensorArray has size zero, but element shape is not " "fully defined. Currently only static shapes are supported when " "packing zero-size TensorArrays.") - with self.assertRaisesOpError(v2_msg if tensor_array_ops - .ENABLE_TENSOR_ARRAY_V2 else v1_msg): + with self.assertRaisesOpError( + v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg): ta.stack().eval() @test_util.run_v1_only("b/120545219") diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index b7e50c1dae5ac1dc0968a3badb8f017e6b0384e1..99216d7fb15ff865ba70d01995606c6a5e3ab7c4 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -24,13 +24,11 @@ from __future__ import print_function import abc import collections import functools -import os import six from tensorflow.core.framework import attr_value_pb2 from tensorflow.core.protobuf import control_flow_pb2 -from tensorflow.python import tf2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -71,9 +69,6 @@ cond_v2 = LazyLoader("cond_v2", globals(), while_v2 = LazyLoader("while_v2", globals(), "tensorflow.python.ops.while_v2") -ENABLE_COND_V2 = tf2.enabled() or os.getenv("TF_ENABLE_COND_V2", "0") != "0" -ENABLE_WHILE_V2 = tf2.enabled() or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" - # We override the 'tuple' for a control flow op, so we keep python's # existing 'tuple' for later use in this module. _basetuple = tuple @@ -2052,7 +2047,7 @@ def cond(pred, ``` """ - if ENABLE_COND_V2 and not context.executing_eagerly(): + if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly(): return cond_v2.cond_v2(pred, true_fn, false_fn, name) # We needed to make true_fn/false_fn keyword arguments for @@ -3487,7 +3482,7 @@ def while_loop(cond, ``` """ - if ENABLE_WHILE_V2 and not context.executing_eagerly(): + if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly(): return while_v2.while_loop( cond, body, diff --git a/tensorflow/python/ops/control_flow_ops_benchmark.py b/tensorflow/python/ops/control_flow_ops_benchmark.py index 9ba5ff2c0f8af44e8536b49a3c0e7ef6bfae4d28..9dd1e6673b854c3cbc248f0e5a5be4c67d2bd72c 100644 --- a/tensorflow/python/ops/control_flow_ops_benchmark.py +++ b/tensorflow/python/ops/control_flow_ops_benchmark.py @@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.platform import test @@ -94,28 +95,28 @@ class CondWithManyIntermediatesBenchmark(test.Benchmark): iters=self.NUM_ITERS) def benchmark_cond_v1_defun(self): - old_val = control_flow_ops.ENABLE_COND_V2 - control_flow_ops.ENABLE_COND_V2 = False + old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = False self._benchmark_defun() - control_flow_ops.ENABLE_COND_V2 = old_val + control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val def benchmark_cond_v2_defun(self): - old_val = control_flow_ops.ENABLE_COND_V2 - control_flow_ops.ENABLE_COND_V2 = True + old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = True self._benchmark_defun() - control_flow_ops.ENABLE_COND_V2 = old_val + control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val def benchmark_cond_v1_graph(self): - old_val = control_flow_ops.ENABLE_COND_V2 - control_flow_ops.ENABLE_COND_V2 = False + old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = False self._benchmark_graph() - control_flow_ops.ENABLE_COND_V2 = old_val + control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val def benchmark_cond_v2_graph(self): - old_val = control_flow_ops.ENABLE_COND_V2 - control_flow_ops.ENABLE_COND_V2 = True + old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2 + control_flow_util.ENABLE_CONTROL_FLOW_V2 = True self._benchmark_graph() - control_flow_ops.ENABLE_COND_V2 = old_val + control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val if __name__ == "__main__": ops.enable_eager_execution() diff --git a/tensorflow/python/ops/control_flow_util.py b/tensorflow/python/ops/control_flow_util.py index cb628f4aa6441ec9cb03dfe873a79d06a66e37a1..1747f06109daa1e7092fd1bbbcd2e2cc5762fc6c 100644 --- a/tensorflow/python/ops/control_flow_util.py +++ b/tensorflow/python/ops/control_flow_util.py @@ -23,10 +23,18 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import traceback +from tensorflow.python import tf2 from tensorflow.python.platform import tf_logging as logging +ENABLE_CONTROL_FLOW_V2 = (tf2.enabled() or + os.getenv("TF_ENABLE_CONTROL_FLOW_V2", "0") != "0" or + os.getenv("TF_ENABLE_COND_V2", "0") != "0" or + os.getenv("TF_ENABLE_WHILE_V2", "0") != "0" or + os.getenv("TF_ENABLE_TENSOR_ARRAY_V2", "0") != "0") + def IsInXLAContext(op): try: diff --git a/tensorflow/python/ops/tensor_array_ops.py b/tensorflow/python/ops/tensor_array_ops.py index d1516949517f1f5df9291add96756eeacea29f51..85333ee6b561c2c593eed3b12caff419eb7c1c84 100644 --- a/tensorflow/python/ops/tensor_array_ops.py +++ b/tensorflow/python/ops/tensor_array_ops.py @@ -20,10 +20,8 @@ from __future__ import division from __future__ import print_function import contextlib -import os import weakref -from tensorflow.python import tf2 from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -32,6 +30,7 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_util from tensorflow.python.ops import gen_control_flow_ops from tensorflow.python.ops import gen_data_flow_ops from tensorflow.python.ops import list_ops @@ -40,10 +39,6 @@ from tensorflow.python.util import tf_should_use from tensorflow.python.util.tf_export import tf_export -ENABLE_TENSOR_ARRAY_V2 = ( - tf2.enabled() or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2") is not None) - - # _GraphTensorArray accesses many of the hidden generated ops, but is in # fact built to wrap these methods. # pylint: disable=protected-access @@ -1013,7 +1008,7 @@ class TensorArray(object): if context.executing_eagerly(): implementation = _EagerTensorArray else: - if ENABLE_TENSOR_ARRAY_V2: + if control_flow_util.ENABLE_CONTROL_FLOW_V2: implementation = _GraphTensorArrayV2 else: implementation = _GraphTensorArray diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index d00c158d156b225553b52437324accd019c76aee..f7566bac9bd290600d24b40875d1b678d8e7ce07 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -52,13 +52,6 @@ from tensorflow.python.util import nest # to them and then pass those in as data inputs. This should probably be # handled in the CapturingGraph itself. -# Op types that output a resource tensor representing a TensorArray handle. -TENSOR_ARRAY_HANDLE_OPS = ( - "TensorArrayV3", - "TensorArrayGradV3", - "TensorArrayGradWithShape", -) - def while_loop(cond, body, @@ -257,24 +250,19 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name "_maximum_iterations") if _is_in_xla_context() else None assert not _is_in_xla_context() or maximum_iterations is not None - # Set the incoming gradient of TensorArray handles to None. The gradient - # implementation currently assumes all resource tensors correspond to float32 - # ResourceVariables, which can lead to runtime shape errors when used with a - # TensorArray. This is a workaround until TensorArrays are reimplemented with - # TensorLists instead of resources. - # Also set the incoming gradient of non-trainable inputs to None. It is - # possible that we receive non-None gradients for non-trainable types in - # nested while loops because we accumulate outputs of the inner while as - # variant tensors which are trainable and hence receive zeros_like tensors in - # the gradient pass. The non-trainable tensors then receive the popped zeros - # tensor from this zeros variant. The gradient for the loop vars corresponding - # to these tensors is None or zeros (this happens only if the loop var is - # accumulated as well) in _grad_fn so we reset these. + # Set the incoming gradient of non-trainable inputs to None. It is possible + # that we receive non-None gradients for non-trainable types in nested while + # loops because we accumulate outputs of the inner while as variant tensors + # which are trainable and hence receive zeros_like tensors in the gradient + # pass. The non-trainable tensors then receive the popped zeros tensor from + # this zeros variant. The gradient for the loop vars corresponding to these + # tensors is None or zeros (this happens only if the loop var is accumulated + # as well) in _grad_fn so we reset these. # TODO(b/118712257): Remove the IsTrainable filter once we can handle None # output grads in _grad_fn. grads = [ - None if _is_tensor_array_handle(output) or not _is_trainable(output) - else grad for grad, output in zip(grads, body_graph.outputs) + None if not _is_trainable(output) else grad + for grad, output in zip(grads, body_graph.outputs) ] # Ensure that all non-resource trainable outputs have incoming gradients. @@ -339,8 +327,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name # See comment in while_loop. outputs = [array_ops.identity(t) for t in outputs] - # Set None as the output gradient for tensors with None input gradient - # e.g. TensorArray handles. + # Set None as the output gradient for tensors with None input gradient. # outputs[0] is the loop counter. # outputs[1] is the total number of loop iterations. index = 2 @@ -853,28 +840,6 @@ def _graph_name(graph): return "Base" -def _is_tensor_array_handle(tensor): - """Returns whether tensor is a TensorArray handle.""" - if tensor.dtype != dtypes.resource: - return False - - if tensor.op.type == "While": - # We assume that any resource outputs of a While op correspond to a captured - # resource input (as opposed to a loop variable specified by the user). - # NOTE(skyewm): we could actually check this, but I can't think of when you - # would have a resource loop variable. - tensor = tensor.op.inputs[tensor.value_index] - - # TODO(b/118452219): add test coverage for this. - tensor = func_graph_module.maybe_captured(tensor) - - if isinstance(tensor, ops.EagerTensor): - # Eager execution doesn't quite support legacy tensorarray - return False - - return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS - - def _pack_sequence_as(structure_with_tas, loop_vars): """Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""