Commit ab3119e5 authored by A. Unique TensorFlower, committed by TensorFlower Gardener

Fix bug concerning computation of gradients inside a CondContext.

Change: 117260468
Parent b27022dd
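The change below lets the gradient machinery query a CondContext for `grad_state` and `back_prop` by forwarding both questions to the enclosing while loop, which matters when gradients are taken through a cond that is nested inside a loop. A minimal sketch of that situation, assuming TF1-style graph execution (`tf.compat.v1` in current releases); the concrete shapes and values are illustrative:

import tensorflow as tf

x = tf.constant(2.0)

def body(i, acc):
  # The true/false branches run inside a CondContext nested in the loop's
  # WhileContext; their gradient code consults grad_state/back_prop on that
  # innermost context.
  acc = tf.cond(i > 1, lambda: acc * x, lambda: acc + x)
  return i + 1, acc

_, result = tf.while_loop(lambda i, acc: i < 3, body, [tf.constant(0), x])
grads = tf.gradients(result, x)  # differentiates through the cond inside the loop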
@@ -142,7 +142,7 @@ def _ExitGrad(_, grad):
   # pylint: enable=protected-access
   if not grad_ctxt.back_prop:
     # The flag `back_prop` is set by users to suppress gradient
-    # computation for this loop. If the flag `back_prop` is true,
+    # computation for this loop. If the attribute `back_prop` is false,
     # no gradient computation.
     return None
   grad_ctxt.AddName(grad.name)
@@ -184,7 +184,7 @@ def _EnterGrad(op, grad):
   grad_ctxt = graph._get_control_flow_context()
   # pylint: enable=protected-access
   if not grad_ctxt.back_prop:
-    # If the flag `back_prop` is true, no gradient computation.
+    # If the attribute `back_prop` is true, no gradient computation.
     return grad
   if op.get_attr("is_constant"):
     # Add a gradient accumulator for each loop invariant.
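Both hunks above implement the early-exit path for loops built with gradient computation disabled. A minimal usage sketch, again assuming TF1-style graph execution; `back_prop=False` is a real argument of `tf.while_loop`, the rest is illustrative:

import tensorflow as tf

x = tf.constant(3.0)
_, y = tf.while_loop(lambda i, v: i < 5,
                     lambda i, v: (i + 1, v * x),
                     [tf.constant(0), x],
                     back_prop=False)  # ask the loop to skip its backward pass
# With back_prop == False on the loop's gradient context, _EnterGrad and
# _ExitGrad above return early, so no backward pass is built through the loop.
grads = tf.gradients(y, x)
print(grads)  # expected: [None], since the loop contributes no gradient path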
@@ -909,6 +909,14 @@ class ControlFlowContext(object):
     """Return the context containing this context."""
     return self._outer_context
 
+  @property
+  def grad_state(self):
+    raise NotImplementedError("Abstract method")
+
+  @property
+  def back_prop(self):
+    raise NotImplementedError("Abstract method")
+
   def AddName(self, name):
     self._values.add(name)
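The base-class additions declare `grad_state` and `back_prop` as read-only properties that every concrete control-flow context must answer, so callers in the gradient code can query any context uniformly. A small sketch of the pattern with hypothetical class names (not TensorFlow's own):

class BaseContext(object):
  # Base control-flow context: subclasses must answer both queries.
  @property
  def grad_state(self):
    raise NotImplementedError("Abstract method")

  @property
  def back_prop(self):
    raise NotImplementedError("Abstract method")


class LoopContext(BaseContext):
  # Concrete context that owns the answers itself.
  def __init__(self, back_prop=True):
    self._back_prop = back_prop
    self._grad_state = None  # set when a gradient loop is constructed

  @property
  def grad_state(self):
    return self._grad_state

  @property
  def back_prop(self):
    return self._back_prop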
@@ -979,6 +987,18 @@ class CondContext(ControlFlowContext):
   def branch(self):
     return self._branch
 
+  @property
+  def grad_state(self):
+    if self.GetWhileContext():
+      return self.GetWhileContext().grad_state
+    return None
+
+  @property
+  def back_prop(self):
+    if self.GetWhileContext():
+      self.GetWhileContext().back_prop
+    return False
+
   def AddValue(self, val):
     """Add `val` to the current context and its outer context recursively."""
     result = val
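CondContext answers both queries by deferring to the while loop that encloses the cond, if any. As shown in the hunk, the `back_prop` branch evaluates `self.GetWhileContext().back_prop` without returning it, so the property would always come out False when the cond sits inside a loop; the sketch below, with hypothetical names, includes the `return` that the delegation appears to intend:

class CondLikeContext(object):
  def __init__(self, enclosing_while_context=None):
    self._while_ctxt = enclosing_while_context

  def get_while_context(self):
    # Nearest enclosing while-loop context, or None when the cond is not
    # inside a loop (stand-in for TensorFlow's GetWhileContext()).
    return self._while_ctxt

  @property
  def grad_state(self):
    # A cond has no gradient state of its own; forward the query.
    while_ctxt = self.get_while_context()
    return while_ctxt.grad_state if while_ctxt else None

  @property
  def back_prop(self):
    # Forward back_prop as well; a cond outside any loop reports False.
    while_ctxt = self.get_while_context()
    return while_ctxt.back_prop if while_ctxt else False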