Commit 0ab76c2c authored by Yuan Yu, committed by TensorFlower Gardener

Make sure the zeros tensor for gradient accumulation for a loop invariant is...

Make sure the zeros tensor for gradient accumulation for a loop invariant is initialized with the right shape. This is needed because this zeros tensor can get passed downstream when the loop exits with no iterations.
Change: 125620906
Parent fee00658
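
For context, a minimal standalone sketch of the scenario this change fixes, written against the TF 0.x Python API used in this commit and mirroring the new testWhileGrad_BaseShape added below (the script form and names here are illustrative, not part of the commit). The loop condition is immediately false, so the body never runs and the gradient flowing out for the loop invariant x is exactly the accumulator's initial zeros tensor; that tensor therefore needs the runtime shape of x even though its static shape is only [None].

import tensorflow as tf  # TF 0.x-era API (tf.placeholder, tf.mul), matching the test below

x = tf.placeholder(tf.float32, [None])   # static shape unknown: [None]
v0 = tf.constant([2.0, 2.0], name="v")
c = lambda v: tf.constant(False)         # condition is always false: zero iterations
b = lambda v: tf.mul(v, x)               # x is a loop invariant captured by the body
r = tf.while_loop(c, b, [v0])

y = tf.square(x)
g = tf.gradients([r, y], x)[0]           # the loop contributes only the accumulator's zeros

with tf.Session() as sess:
  # With the fix, the zeros accumulator takes x's runtime shape, so the result
  # is just dy/dx = 2 * x, as the new test asserts.
  print(sess.run(g, feed_dict={x: [1.0, 2.0]}))  # [2.0, 4.0]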
@@ -950,6 +950,18 @@ class ControlFlowTest(tf.test.TestCase):
       self.assertEqual([None], r.get_shape().as_list())
       self.assertAllClose([810.0, 2560.0], r.eval(feed_dict={x: [3.0, 4.0]}))
 
+  def testWhileGrad_BaseShape(self):
+    with self.test_session() as sess:
+      x = tf.placeholder(tf.float32, [None])
+      v0 = tf.constant([2.0, 2.0], name="v")
+      c = lambda v: tf.constant(False)
+      b = lambda v: tf.mul(v, x)
+      r = tf.while_loop(c, b, [v0])
+
+      y = tf.square(x)
+      r = tf.gradients([r, y], x)[0]
+      self.assertAllClose([2.0, 4.0], sess.run(r, feed_dict={x: [1.0, 2.0]}))
+
   def testWhileGrad_MultipleUses(self):
     with self.test_session():
       v = tf.constant(2.0, name="v")
@@ -199,7 +199,7 @@ def _EnterGrad(op, grad):
   if op.get_attr("is_constant"):
     # Add a gradient accumulator for each loop invariant.
     if isinstance(grad, ops.Tensor):
-      result = grad_ctxt.AddBackPropAccumulator(grad)
+      result = grad_ctxt.AddBackPropAccumulator(op, grad)
     elif isinstance(grad, ops.IndexedSlices):
       result = grad_ctxt.AddBackPropIndexedSlicesAccumulator(grad)
     else:
@@ -773,13 +773,13 @@ class GradLoopState(object):
           # Record the history of this value in forward_ctxt.
           # TODO(yuanbyu): Avoid recording constants.
           self._grad_context.Exit()
-          h_value = cur_grad_state.AddForwardAccumulator(cur_value)
+          history_value = cur_grad_state.AddForwardAccumulator(cur_value)
           self._grad_context.Enter()
           break
 
       if real_value is None:
         # Add the stack pop op in the grad context.
-        real_value = self.AddBackPropAccumulatedValue(h_value, value)
+        real_value = self.AddBackPropAccumulatedValue(history_value, value)
       self._history_map[value.name] = real_value
     return real_value
@@ -966,13 +966,13 @@ class ControlFlowState(object):
       # Add forward accumulator for shape.
       grad_state.grad_context.Exit()
-      h_shape = grad_state.AddForwardAccumulator(
+      history_zeros_shape = grad_state.AddForwardAccumulator(
           zeros_shape, dead_branch=dead_branch)
       grad_state.grad_context.Enter()
 
       # Create a zero tensor with the right shape.
       shape = grad_state.AddBackPropAccumulatedValue(
-          h_shape, zeros_shape, dead_branch)
+          history_zeros_shape, zeros_shape, dead_branch)
       result = array_ops.zeros(shape, val.dtype)
     return result
@@ -1596,36 +1596,56 @@ class WhileContext(ControlFlowContext):
     self.Exit()
     return next_count
 
-  def AddBackPropAccumulator(self, value):
+  def AddBackPropAccumulator(self, op, grad):
     """Add an accumulation loop for every loop invariant.
 
-    This is added to the backprop loop. It is used to accumulate
-    partial gradients within each loop iteration. Called when in the
-    gradient while context.
+    This is added to the backprop loop. It is used to accumulate partial
+    gradients within each loop iteration. Called when in the gradient while
+    context.
 
     The pseudocode is:
       ```
       acc = 0.0;
       while (_pivot) {
-        acc += value;
+        acc += grad;
       }
       ```
 
     Args:
-      value: The partial gradient of an iteration for a loop invariant.
+      op: The Enter op for a loop invariant.
+      grad: The partial gradient of an iteration for a loop invariant.
 
     Returns:
       The gradient for a loop invariant.
     """
     self.Exit()
-    shape = value.get_shape()
-    if not shape.is_fully_defined():
-      shape = None
-    if self.outer_context: self.outer_context.Enter()
-    acc = constant_op.constant(0, value.dtype, shape=shape, name="b_acc")
-    if not shape:
-      acc._shape = value.get_shape()  # pylint: disable=protected-access
-    if self.outer_context: self.outer_context.Exit()
+    # Create a zeros tensor with the right shape for acc. If we don't
+    # know the full shape statically, we will have to get the shape
+    # dynamically from the forward inference. Getting the shape right
+    # for the zeros is only needed for the base case when the loop exits
+    # without running any iterations.
+    shape = grad.get_shape()
+    if shape.is_fully_defined():
+      if self.outer_context: self.outer_context.Enter()
+      acc = constant_op.constant(0, grad.dtype, shape=shape, name="b_acc")
+      if self.outer_context: self.outer_context.Exit()
+    else:
+      value = op.inputs[0]
+      if self.outer_context:
+        forward_ctxt = self.grad_state.forward_ctxt
+        forward_ctxt.outer_context.Enter()
+        zeros_shape = array_ops.shape(value)
+        forward_ctxt.outer_context.Exit()
+        history_zeros_shape = grad_state.AddForwardAccumulator(zeros_shape)
+        self.outer_context.Enter()
+        real_shape = outer_grad_state.AddBackPropAccumulatedValue(
+            history_zeros_shape, zeros_shape)
+        acc = array_ops.zeros(real_shape, grad.dtype)
+        self.outer_context.Exit()
+      else:
+        zeros_shape = array_ops.shape(value)
+        acc = array_ops.zeros(zeros_shape, grad.dtype)
+      acc._shape = grad.get_shape()  # pylint: disable=protected-access
     self.Enter()
     self.AddName(acc.name)
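
As the updated docstring in this hunk describes, the backprop accumulator sums one partial gradient per loop iteration for each loop invariant. A small hedged illustration of that behavior at the public API level (TF 0.x-era API; the particular loop and numbers are illustrative, not taken from this commit): multiplying the carried value by the invariant x three times gives r = x**3, and the accumulated gradient is 3 * x**2.

import tensorflow as tf  # TF 0.x-era API; illustrative example, not library-internal code

x = tf.constant(2.0)                  # loop invariant
i0 = tf.constant(0)
v0 = tf.constant(1.0)
c = lambda i, v: i < 3
b = lambda i, v: (i + 1, v * x)       # each iteration contributes the partial d(v * x)/dx
_, r = tf.while_loop(c, b, [i0, v0])  # r = x**3 = 8
g = tf.gradients(r, x)[0]             # partials accumulated over 3 iterations: 3 * x**2 = 12

with tf.Session() as sess:
  print(sess.run([r, g]))             # [8.0, 12.0]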
@@ -1633,30 +1653,30 @@ class WhileContext(ControlFlowContext):
                        parallel_iterations=self._parallel_iterations,
                        name="b_acc")
     merge_acc = merge([enter_acc, enter_acc], name="b_acc")[0]
-    switch_acc = switch(merge_acc, self._pivot)
+    switch_acc_false, switch_acc_true = switch(merge_acc, self._pivot)
 
-    add_acc = math_ops.add(switch_acc[1], value)
+    add_acc = math_ops.add(switch_acc_true, grad)
     next_acc = _NextIteration(add_acc)
     merge_acc.op._update_input(1, next_acc)  # pylint: disable=protected-access
 
-    acc_result = exit(switch_acc[0], name="b_acc")
+    acc_result = exit(switch_acc_false, name="b_acc")
     self.ExitResult([acc_result])
     return acc_result
 
-  def AddBackPropIndexedSlicesAccumulator(self, value):
+  def AddBackPropIndexedSlicesAccumulator(self, grad):
     """This is used for accumulating gradients that are IndexedSlices.
 
     This is essentially the equavalent of AddBackPropAccumulator but optimized
     for things like updating embeddings from within a while loop.
 
     Args:
-      value: The partial gradients represented as an IndexedSlices.
+      grad: The partial gradients represented as an IndexedSlices.
 
     Returns:
       The accumulated IndexedSlices gradient of the loop invariant.
     """
-    values = value.values
-    indices = value.indices
+    values = grad.values
+    indices = grad.indices
     self.Exit()
     shape = tensor_shape.TensorShape([tensor_shape.Dimension(1)] +
@@ -1670,6 +1690,7 @@ class WhileContext(ControlFlowContext):
     values_acc._shape = shape  # pylint: disable=protected-access
     indices_acc = constant_op.constant([0], indices.dtype)
     if self.outer_context: self.outer_context.Exit()
+
     self.Enter()
     self.AddName(values_acc.name)
     self.AddName(indices_acc.name)
@@ -1687,10 +1708,10 @@ class WhileContext(ControlFlowContext):
     for xm, xn in zip(merge_acc, next_acc):
       xm.op._update_input(1, xn)  # pylint: disable=protected-access
 
-    acc_result = [exit(x[0], name="b_acc") for x in switch_acc]
-    self.ExitResult(acc_result)
-    return ops.IndexedSlices(values=acc_result[1], indices=acc_result[0],
-                             dense_shape=self.ExitResult(value.dense_shape))
+    acc_exits = [exit(x[0], name="b_acc") for x in switch_acc]
+    self.ExitResult(acc_exits)
+    return ops.IndexedSlices(values=acc_exits[1], indices=acc_exits[0],
+                             dense_shape=grad.dense_shape)
 
   def _InitializeValues(self, values):
     self._values = set()
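
The heart of the fix is the new else branch in AddBackPropAccumulator: when the gradient's static shape is not fully defined, the accumulator's initial zeros are built from a dynamically computed shape so that they match the loop invariant even if the loop exits after zero iterations. A simplified sketch of that static-vs-dynamic choice outside the while-loop machinery (the helper name is hypothetical; TF 0.x-era API):

import tensorflow as tf  # TF 0.x-era API; simplified sketch, not the internal control-flow code

def zeros_with_right_shape(value):
  # Hypothetical helper illustrating the shape choice made when initializing
  # the backprop accumulator for a loop invariant.
  shape = value.get_shape()
  if shape.is_fully_defined():
    # Static shape known: a constant zeros tensor already has the right shape.
    return tf.zeros(shape.as_list(), value.dtype)
  # Static shape unknown (e.g. a placeholder with shape [None]): take the
  # shape at run time so the zeros match the actual value.
  return tf.zeros(tf.shape(value), value.dtype)

x = tf.placeholder(tf.float32, [None])
z = zeros_with_right_shape(x)          # shape comes from tf.shape(x) at run time

with tf.Session() as sess:
  print(sess.run(z, feed_dict={x: [1.0, 2.0, 3.0]}))  # [0.0, 0.0, 0.0]

In the real implementation, the dynamic shape additionally has to be routed through the forward loop's accumulator (AddForwardAccumulator / AddBackPropAccumulatedValue) when the invariant lives in an outer context, which is what the nested "if self.outer_context" branch in the diff handles.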