提交 28fd3038 编写于 作者: Y Yuan Yu 提交者: TensorFlower Gardener

Set up the generated gradient while loops properly so that `None` is...

Set up the generated gradient while loops properly so that `None` is propagated when using tf.stop_gradient.
Change: 144108792
上级 6a661a02
......@@ -2069,6 +2069,22 @@ class ControlFlowTest(test.TestCase):
grad_theta_stopped = array_ops.stop_gradient(grad_theta)
gradients_impl.gradients(grad_theta_stopped, theta)
def testStopGradOnWhileGrad(self):
  """Checks backprop through a sum that includes a stopped while-loop gradient.

  The loop multiplies `x` by `y` until the value reaches 100, so with
  x = y = 2 it yields x * y**6 = 128 and d(loop)/dy = 6 * x * y**5 = 384.
  The first-order gradient is wrapped in `stop_gradient` before being added
  to the result, so it must contribute nothing to the final gradient:
  d(y**2)/dy + d(loop)/dy = 4 + 384 = 388.
  """
  with self.test_session():
    x = constant_op.constant(2.0, name="x")
    y = constant_op.constant(2.0, name="y")

    def keep_going(v):
      return math_ops.less(v, 100.0)

    def scale_by_y(v):
      return math_ops.mul(v, y)

    loop_out = control_flow_ops.while_loop(keep_going, scale_by_y, [x])
    # First-order gradient of the loop w.r.t. y, then cut off from backprop.
    loop_grad = gradients_impl.gradients(loop_out, y)[0]
    stopped_grad = array_ops.stop_gradient(loop_grad)
    total = math_ops.add(math_ops.square(y), loop_out)
    total = math_ops.add(total, stopped_grad)
    # Differentiate the sum; the stopped term should be skipped entirely.
    grad_total = gradients_impl.gradients(total, y)[0]
    self.assertEqual(388.0, grad_total.eval())
def testOneValueCond(self):
with self.test_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
......
......@@ -203,7 +203,7 @@ def _EnterGrad(op, grad):
# Skip gradient computation, if the attribute `back_prop` is false.
return grad
if grad_ctxt.grad_state is None:
# Pass the gradient grough if we are not in a gradient while context.
# Pass the gradient through if we are not in a gradient while context.
return grad
if op.get_attr("is_constant"):
# Add a gradient accumulator for each loop invariant.
......@@ -216,6 +216,7 @@ def _EnterGrad(op, grad):
raise TypeError("Type %s not supported" % type(grad))
else:
result = exit(grad)
grad_ctxt.loop_exits.append(result)
grad_ctxt.ExitResult([result])
return result
......
......@@ -703,6 +703,7 @@ class GradLoopState(object):
self._switch_map = {}
self._unused_exits = []
self._deferred_exits = []
self._forward_loop_exits = list(forward_ctxt.loop_exits)
self._pending_exits_count = len(forward_ctxt.loop_exits)
self._outer_grad_state = outer_grad_state
......@@ -820,6 +821,11 @@ class GradLoopState(object):
"""The list of "deferred" exits."""
return self._deferred_exits
@property
def forward_loop_exits(self):
  """Returns the forward loop's exits held by this gradient loop state.

  The value is the list stored in `_forward_loop_exits`; the same list object
  is returned on every access, not a copy.
  """
  return self._forward_loop_exits
@property
def pending_exits_count(self):
"""The number of exits we expect to see but haven't."""
......@@ -1059,8 +1065,8 @@ class ControlFlowState(object):
to backprop.
"""
loop_exits = []
for forward_ctxt, grad_state in self._map.items():
for y in forward_ctxt.loop_exits:
for _, grad_state in self._map.items():
for y in grad_state.forward_loop_exits:
# pylint: disable=protected-access
if pending_count[y.op._id] == 0:
grad_state.pending_exits_count -= 1
......@@ -1105,7 +1111,7 @@ class ControlFlowState(object):
self._map[forward_ctxt] = grad_state
# We need to include all exits of a loop for backprop.
for loop_exit in forward_ctxt.loop_exits:
for loop_exit in grad_state.forward_loop_exits:
if not between_ops[loop_exit.op._id]:
between_ops[loop_exit.op._id] = True
between_op_list.append(loop_exit.op)
......@@ -2119,6 +2125,7 @@ class WhileContext(ControlFlowContext):
merge_n.op._update_input(1, next_n)
total_iterations = exit(switch_n[0], name="f_count")
self.loop_exits.append(total_iterations)
self.ExitResult([total_iterations])
self.Exit()
return total_iterations, next_n
......@@ -2163,6 +2170,7 @@ class WhileContext(ControlFlowContext):
merge_count.op._update_input(1, next_count)
final_zero = exit(switch_count[0], name="b_count")
self.loop_exits.append(final_zero)
if outer_grad_state is not None:
# Force the stack pops of i-th execution of an inner loop to be ordered
# before the pops of (i+1)-th execution of the same inner loop.
......@@ -2244,6 +2252,7 @@ class WhileContext(ControlFlowContext):
merge_acc.op._update_input(1, next_acc) # pylint: disable=protected-access
acc_result = exit(switch_acc_false, name="b_acc")
self.loop_exits.append(acc_result)
self.ExitResult([acc_result])
return acc_result
......@@ -2320,6 +2329,7 @@ class WhileContext(ControlFlowContext):
xm.op._update_input(1, xn) # pylint: disable=protected-access
acc_exits = [exit(x[0], name="b_acc") for x in switch_acc]
self.loop_exits.extend(acc_exits)
self.ExitResult(acc_exits)
return ops.IndexedSlices(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册