提交 1d54cbf4 编写于 作者: S Skye Wanderman-Milne 提交者: TensorFlower Gardener

Introduce consolidated ENABLE_CONTROL_FLOW_V2 flag.

The new toggle replaces ENABLE_COND_V2, ENABLE_WHILE_V2, and
ENABLE_TENSOR_ARRAY_V2. These features can therefore no longer be
toggled independently; notably, v1 TensorArrays can only be run with
v1 loops, and v2 TensorArrays only with v2 loops.

This also introduces a corresponding environment variable,
TF_ENABLE_CONTROL_FLOW_V2. The old environment variables
(TF_ENABLE_COND_V2, TF_ENABLE_WHILE_V2, and TF_ENABLE_TENSOR_ARRAY_V2)
are kept in case people are using them; each of them now flips the new
single toggle.

In addition, this change removes some while_v2 code for dealing with
v1 TensorArrays, since this is no longer a supported configuration.

PiperOrigin-RevId: 224862245
上级 ee418c8e
......@@ -32,6 +32,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
......@@ -500,10 +501,10 @@ class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
def testMapAndBatchControlFlow(self, numa_aware):
def map_fn(x):
previous_cond_v2_value = control_flow_ops.ENABLE_COND_V2
control_flow_ops.ENABLE_COND_V2 = True
previous_control_flow_v2_value = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
return_value = control_flow_ops.cond(x < 50, lambda: x + 1, lambda: x * x)
control_flow_ops.ENABLE_COND_V2 = previous_cond_v2_value
control_flow_util.ENABLE_CONTROL_FLOW_V2 = previous_control_flow_v2_value
return return_value
dataset = dataset_ops.Dataset.range(100).apply(
......
......@@ -67,9 +67,8 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
......@@ -409,42 +408,12 @@ def enable_control_flow_v2(fn):
"""
def wrapper(*args, **kwargs):
enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
control_flow_ops.ENABLE_COND_V2 = True
control_flow_ops.ENABLE_WHILE_V2 = True
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
try:
fn(*args, **kwargs)
finally:
control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
return wrapper
def enable_tensor_array_v2(fn):
"""Decorator for enabling _GraphTensorArrayV2 on a test.
Note this enables _GraphTensorArrayV2 after running the test class's
setup/teardown methods.
Args:
fn: the function to be wrapped
Returns:
The wrapped function
"""
def wrapper(*args, **kwargs):
enable_tensor_array_v2_old = tensor_array_ops.ENABLE_TENSOR_ARRAY_V2
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = True
try:
fn(*args, **kwargs)
finally:
tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 = enable_tensor_array_v2_old
control_flow_util.ENABLE_CONTROL_FLOW_V2 = enable_control_flow_v2_old
return wrapper
......@@ -493,7 +462,7 @@ def with_control_flow_v2(cls):
Returns:
cls with new test methods added
"""
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
return cls
for name, value in cls.__dict__.copy().items():
......
......@@ -43,6 +43,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import gen_array_ops
......@@ -700,7 +701,8 @@ class ControlFlowTest(test.TestCase):
v1_msg = "The two structures don't have the same nested structure"
v2_msg = "Outputs of true_fn and false_fn must have the same structure"
with self.assertRaisesRegexp(
ValueError, v2_msg if control_flow_ops.ENABLE_COND_V2 else v1_msg):
ValueError,
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
r = control_flow_ops.cond(pred, fn1, fn2)
self.evaluate(r)
......@@ -859,7 +861,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(sess.run(grad, {pred: False, x: 1.0, y: 2.0}), 0.0)
# v1 control flow gets None second derivative for some reason.
if not control_flow_ops.ENABLE_COND_V2:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertIsNone(grad_grad)
return
......@@ -949,7 +951,7 @@ class ControlFlowTest(test.TestCase):
# In defuns, all prints should execute in program order.
# This doesn't work with legacy control flow.
if control_flow_ops.ENABLE_COND_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
@eager_function.defun
def cond():
......@@ -1003,7 +1005,7 @@ class ControlFlowTest(test.TestCase):
# In defuns, all prints should execute in program order.
# This doesn't work with legacy control flow.
if control_flow_ops.ENABLE_WHILE_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
@eager_function.defun
def while_loop():
......@@ -1161,7 +1163,7 @@ class ControlFlowTest(test.TestCase):
gs = gradients_impl.gradients(loop_no_xla, v)
self.evaluate(gs) # This should execute without error.
if control_flow_ops.ENABLE_WHILE_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
with self.assertRaisesRegexp(
......@@ -1219,7 +1221,7 @@ class ControlFlowTest(test.TestCase):
lambda i, x: (i + 1, v * x), (0, 1.0),
maximum_iterations=max_iter_holder[0])
if control_flow_ops.ENABLE_WHILE_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
xla_context = control_flow_ops.XLAControlFlowContext()
xla_context.Enter()
with self.assertRaisesRegexp(
......@@ -1863,7 +1865,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(sess.run(grad, {pred: True}), 8.0)
self.assertEqual(sess.run(grad, {pred: False}), 0.0)
if not control_flow_ops.ENABLE_WHILE_V2:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
return
self.assertEqual(sess.run(grad_grad, {pred: True}), 0.0)
......@@ -2399,7 +2401,7 @@ class ControlFlowTest(test.TestCase):
# outer_loop(x) = g(g(x)) = 4x + 81
# outer_loop'(x) = 4
# Note that v1 control flow gets 4.0 as well if the cond is removed.
if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertEqual(grad, 4.0)
def testWhile_NestedInput(self):
......@@ -2982,7 +2984,7 @@ class ControlFlowTest(test.TestCase):
result = functional_ops.scan(fn, np.array([1., 2., 3.], dtype=np.float32))
grad_theta = gradients_impl.gradients(result, theta)
if not control_flow_ops.ENABLE_WHILE_V2:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
with self.assertRaisesRegexp(TypeError, "Second-order gradient"):
gradients_impl.gradients(grad_theta, theta)
grad_theta_stopped = array_ops.stop_gradient(grad_theta)
......@@ -3514,7 +3516,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(r[1].eval(), 65536.0)
self.assertEqual(grad.eval(), 524288.0)
# while_v2 does not have stacks.
if not control_flow_ops.ENABLE_WHILE_V2:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertEqual(
len([op for op in x.graph.get_operations() if op.type == "StackV2"
]), 1)
......
......@@ -23,6 +23,7 @@ from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import control_flow_util_v2
from tensorflow.python.platform import test
......@@ -30,14 +31,11 @@ from tensorflow.python.platform import test
class ControlFlowUtilV2Test(test.TestCase):
def setUp(self):
self._enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
self._enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
control_flow_ops.ENABLE_COND_V2 = True
control_flow_ops.ENABLE_WHILE_V2 = True
self._enable_control_flow_v2_old = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
def tearDown(self):
control_flow_ops.ENABLE_COND_V2 = self._enable_cond_v2_old
control_flow_ops.ENABLE_WHILE_V2 = self._enable_while_v2_old
control_flow_util.ENABLE_CONTROL_FLOW_V2 = self._enable_control_flow_v2_old
def _create_control_flow(self, expect_in_defun):
"""Helper method for testInDefun."""
......
......@@ -32,6 +32,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gradients_impl
......@@ -345,7 +346,7 @@ class TensorArrayTest(test.TestCase):
@test_util.run_deprecated_v1
def testSkipEagerTensorArrayGradGrad(self):
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.skipTest("Legacy TensorArray does not support double derivatives.")
with self.test_session(use_gpu=True) as session:
x = constant_op.constant(4.0)
......@@ -429,7 +430,7 @@ class TensorArrayTest(test.TestCase):
with self.session(use_gpu=True):
ta = _make_ta(3, "foo", dtype=dtypes.float32)
# Test writing the wrong datatype
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = ("Invalid data types; op elements string but list elements "
"float")
......@@ -440,7 +441,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(error_msg):
self.evaluate(ta.write(0, "wrong_type_scalar").flow)
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = "Trying to modify element -1 in a list with 3 elements."
else:
......@@ -448,7 +449,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(error_msg):
self.evaluate(ta.write(-1, 3.0).flow)
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = "Trying to modify element 3 in a list with 3 elements"
else:
......@@ -467,14 +468,14 @@ class TensorArrayTest(test.TestCase):
# Test reading wrong datatype (only possible when constructing graphs).
if (not context.executing_eagerly() and
not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2):
not control_flow_util.ENABLE_CONTROL_FLOW_V2):
r0_bad = gen_data_flow_ops.tensor_array_read_v3(
handle=w0.handle, index=0, dtype=dtypes.float64, flow_in=w0.flow)
with self.assertRaisesOpError(
"TensorArray dtype is float but Op requested dtype double."):
self.evaluate(r0_bad)
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = "Trying to access element -1 in a list with 3 elements."
else:
......@@ -483,7 +484,7 @@ class TensorArrayTest(test.TestCase):
with self.assertRaisesOpError(error_msg):
self.evaluate(ta.read(-1))
if (tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
if (control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not context.executing_eagerly()):
error_msg = "Trying to access element 3 in a list with 3 elements."
else:
......@@ -550,7 +551,7 @@ class TensorArrayTest(test.TestCase):
ta.split([1.0, 2.0, 3.0], lengths).flow.eval(feed_dict={lengths: 1})
error_msg = ("Unused values in tensor. Length of tensor: 3 Values used: 1"
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and
if control_flow_util.ENABLE_CONTROL_FLOW_V2 and
not in_eager_mode else
r"Expected sum of lengths to be equal to values.shape\[0\], "
r"but sum of lengths is 1 and value's shape is: \[3\]")
......@@ -558,7 +559,7 @@ class TensorArrayTest(test.TestCase):
self.evaluate(ta.split([1.0, 2.0, 3.0], [1]).flow)
ta = _make_ta(1, "baz")
if tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 and not in_eager_mode:
if control_flow_util.ENABLE_CONTROL_FLOW_V2 and not in_eager_mode:
with self.assertRaisesRegexp(
ValueError, "Shape must be at least rank 1 but is rank 0"):
self.evaluate(ta.split(1.0, [1]).flow)
......@@ -568,7 +569,7 @@ class TensorArrayTest(test.TestCase):
):
self.evaluate(ta.split(1.0, [1]).flow)
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2 or in_eager_mode:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2 or in_eager_mode:
ta = _make_ta(2, "buz")
with self.assertRaisesOpError(
r"TensorArray's size is not equal to the size of lengths "
......@@ -1003,21 +1004,6 @@ class TensorArrayTest(test.TestCase):
# self._testWhileLoopWritePackGradients(
# dynamic_size=False, dtype=tf.int64)
@test_util.disable_control_flow_v2("Testing v1 while_loop with v2 TA")
@test_util.enable_tensor_array_v2
def testWhileLoopV1WithTensorArrayV2(self):
size = 3
ta = tensor_array_ops.TensorArray(
dtype=dtypes.int32, size=size, element_shape=tensor_shape.scalar())
def Body(counter, ta):
return counter + 1, ta.write(counter, counter)
_, ta = control_flow_ops.while_loop(lambda i, _: i < size, Body, [0, ta])
for i in range(size):
self.assertEqual(self.evaluate(ta.read(i)), i)
@test_util.disable_control_flow_v2("b/117943489 (dynamic_size)")
@test_util.run_v1_only("b/117943489")
def testSkipEagerWhileLoopDynamicWritePackGradients(self):
......@@ -1270,7 +1256,7 @@ class TensorArrayTest(test.TestCase):
self.assertEqual((2, 2), w0.read(1).get_shape())
else:
self.assertEqual(r0.get_shape().ndims, None)
if not tensor_array_ops.ENABLE_TENSOR_ARRAY_V2:
if not control_flow_util.ENABLE_CONTROL_FLOW_V2:
self.assertEqual(
tensor_shape.TensorShape(
ta1.handle.op.get_attr("element_shape")).ndims, None)
......@@ -1347,8 +1333,8 @@ class TensorArrayTest(test.TestCase):
"TensorArray has size zero, but element shape <unknown> is not "
"fully defined. Currently only static shapes are supported when "
"packing zero-size TensorArrays.")
with self.assertRaisesOpError(v2_msg if tensor_array_ops
.ENABLE_TENSOR_ARRAY_V2 else v1_msg):
with self.assertRaisesOpError(
v2_msg if control_flow_util.ENABLE_CONTROL_FLOW_V2 else v1_msg):
ta.stack().eval()
@test_util.run_v1_only("b/120545219")
......
......@@ -24,13 +24,11 @@ from __future__ import print_function
import abc
import collections
import functools
import os
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.protobuf import control_flow_pb2
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
......@@ -71,9 +69,6 @@ cond_v2 = LazyLoader("cond_v2", globals(),
while_v2 = LazyLoader("while_v2", globals(),
"tensorflow.python.ops.while_v2")
ENABLE_COND_V2 = tf2.enabled() or os.getenv("TF_ENABLE_COND_V2", "0") != "0"
ENABLE_WHILE_V2 = tf2.enabled() or os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
# existing 'tuple' for later use in this module.
_basetuple = tuple
......@@ -2052,7 +2047,7 @@ def cond(pred,
```
"""
if ENABLE_COND_V2 and not context.executing_eagerly():
if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
return cond_v2.cond_v2(pred, true_fn, false_fn, name)
# We needed to make true_fn/false_fn keyword arguments for
......@@ -3487,7 +3482,7 @@ def while_loop(cond,
```
"""
if ENABLE_WHILE_V2 and not context.executing_eagerly():
if util.ENABLE_CONTROL_FLOW_V2 and not context.executing_eagerly():
return while_v2.while_loop(
cond,
body,
......
......@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
......@@ -94,28 +95,28 @@ class CondWithManyIntermediatesBenchmark(test.Benchmark):
iters=self.NUM_ITERS)
def benchmark_cond_v1_defun(self):
old_val = control_flow_ops.ENABLE_COND_V2
control_flow_ops.ENABLE_COND_V2 = False
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
self._benchmark_defun()
control_flow_ops.ENABLE_COND_V2 = old_val
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
def benchmark_cond_v2_defun(self):
old_val = control_flow_ops.ENABLE_COND_V2
control_flow_ops.ENABLE_COND_V2 = True
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
self._benchmark_defun()
control_flow_ops.ENABLE_COND_V2 = old_val
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
def benchmark_cond_v1_graph(self):
old_val = control_flow_ops.ENABLE_COND_V2
control_flow_ops.ENABLE_COND_V2 = False
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = False
self._benchmark_graph()
control_flow_ops.ENABLE_COND_V2 = old_val
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
def benchmark_cond_v2_graph(self):
old_val = control_flow_ops.ENABLE_COND_V2
control_flow_ops.ENABLE_COND_V2 = True
old_val = control_flow_util.ENABLE_CONTROL_FLOW_V2
control_flow_util.ENABLE_CONTROL_FLOW_V2 = True
self._benchmark_graph()
control_flow_ops.ENABLE_COND_V2 = old_val
control_flow_util.ENABLE_CONTROL_FLOW_V2 = old_val
if __name__ == "__main__":
ops.enable_eager_execution()
......
......@@ -23,10 +23,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import traceback
from tensorflow.python import tf2
from tensorflow.python.platform import tf_logging as logging
# Single switch that selects the v2 implementations of cond, while_loop and
# TensorArray in graph mode. It is on when tf2.enabled() is true, or when any
# of the environment variables below is set to a value other than "0". The
# three per-feature variables are legacy names that flip this same switch.
ENABLE_CONTROL_FLOW_V2 = tf2.enabled() or any(
    os.getenv(env_var, "0") != "0"
    for env_var in ("TF_ENABLE_CONTROL_FLOW_V2",
                    "TF_ENABLE_COND_V2",
                    "TF_ENABLE_WHILE_V2",
                    "TF_ENABLE_TENSOR_ARRAY_V2"))
def IsInXLAContext(op):
try:
......
......@@ -20,10 +20,8 @@ from __future__ import division
from __future__ import print_function
import contextlib
import os
import weakref
from tensorflow.python import tf2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
......@@ -32,6 +30,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import list_ops
......@@ -40,10 +39,6 @@ from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
ENABLE_TENSOR_ARRAY_V2 = (
tf2.enabled() or os.getenv("TF_ENABLE_TENSOR_ARRAY_V2") is not None)
# _GraphTensorArray accesses many of the hidden generated ops, but is in
# fact built to wrap these methods.
# pylint: disable=protected-access
......@@ -1013,7 +1008,7 @@ class TensorArray(object):
if context.executing_eagerly():
implementation = _EagerTensorArray
else:
if ENABLE_TENSOR_ARRAY_V2:
if control_flow_util.ENABLE_CONTROL_FLOW_V2:
implementation = _GraphTensorArrayV2
else:
implementation = _GraphTensorArray
......
......@@ -52,13 +52,6 @@ from tensorflow.python.util import nest
# to them and then pass those in as data inputs. This should probably be
# handled in the CapturingGraph itself.
# Op types that output a resource tensor representing a TensorArray handle.
TENSOR_ARRAY_HANDLE_OPS = (
"TensorArrayV3",
"TensorArrayGradV3",
"TensorArrayGradWithShape",
)
def while_loop(cond,
body,
......@@ -257,24 +250,19 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
"_maximum_iterations") if _is_in_xla_context() else None
assert not _is_in_xla_context() or maximum_iterations is not None
# Set the incoming gradient of TensorArray handles to None. The gradient
# implementation currently assumes all resource tensors correspond to float32
# ResourceVariables, which can lead to runtime shape errors when used with a
# TensorArray. This is a workaround until TensorArrays are reimplemented with
# TensorLists instead of resources.
# Also set the incoming gradient of non-trainable inputs to None. It is
# possible that we receive non-None gradients for non-trainable types in
# nested while loops because we accumulate outputs of the inner while as
# variant tensors which are trainable and hence receive zeros_like tensors in
# the gradient pass. The non-trainable tensors then receive the popped zeros
# tensor from this zeros variant. The gradient for the loop vars corresponding
# to these tensors is None or zeros (this happens only if the loop var is
# accumulated as well) in _grad_fn so we reset these.
# Set the incoming gradient of non-trainable inputs to None. It is possible
# that we receive non-None gradients for non-trainable types in nested while
# loops because we accumulate outputs of the inner while as variant tensors
# which are trainable and hence receive zeros_like tensors in the gradient
# pass. The non-trainable tensors then receive the popped zeros tensor from
# this zeros variant. The gradient for the loop vars corresponding to these
# tensors is None or zeros (this happens only if the loop var is accumulated
# as well) in _grad_fn so we reset these.
# TODO(b/118712257): Remove the IsTrainable filter once we can handle None
# output grads in _grad_fn.
grads = [
None if _is_tensor_array_handle(output) or not _is_trainable(output)
else grad for grad, output in zip(grads, body_graph.outputs)
None if not _is_trainable(output) else grad
for grad, output in zip(grads, body_graph.outputs)
]
# Ensure that all non-resource trainable outputs have incoming gradients.
......@@ -339,8 +327,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
# See comment in while_loop.
outputs = [array_ops.identity(t) for t in outputs]
# Set None as the output gradient for tensors with None input gradient
# e.g. TensorArray handles.
# Set None as the output gradient for tensors with None input gradient.
# outputs[0] is the loop counter.
# outputs[1] is the total number of loop iterations.
index = 2
......@@ -853,28 +840,6 @@ def _graph_name(graph):
return "Base"
def _is_tensor_array_handle(tensor):
"""Returns whether tensor is a TensorArray handle."""
if tensor.dtype != dtypes.resource:
return False
if tensor.op.type == "While":
# We assume that any resource outputs of a While op correspond to a captured
# resource input (as opposed to a loop variable specified by the user).
# NOTE(skyewm): we could actually check this, but I can't think of when you
# would have a resource loop variable.
tensor = tensor.op.inputs[tensor.value_index]
# TODO(b/118452219): add test coverage for this.
tensor = func_graph_module.maybe_captured(tensor)
if isinstance(tensor, ops.EagerTensor):
# Eager execution doesn't quite support legacy tensorarray
return False
return tensor.op.type in TENSOR_ARRAY_HANDLE_OPS
def _pack_sequence_as(structure_with_tas, loop_vars):
"""Like `nest.pack_sequence_as` but also replaces flows with TensorArrays."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册