Unverified commit c75bb66a, authored by Alexandre Passos, committed by GitHub

Merge pull request #31898 from mrry/cherrypicks_CPD6K

[r2.0 cherrypick] Fix tf.gradients() performance regression
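The diff below removes the per-call `tensors = object_identity.ObjectIdentitySet(xs)` from `_Inputs` and instead builds `xs_set` once in `_GradientsHelper`, threading it through every traversal helper. Since `_Inputs`/`_NonEagerInputs` run for each visited op, the old code paid O(len(xs)) set construction per op. A minimal standalone sketch of the hoisting pattern (not the actual TensorFlow code; `inputs_before`, `inputs_after`, and the toy objects are hypothetical):

import timeit

class ObjectIdentitySet(object):
  """Toy stand-in for tensorflow.python.util.object_identity.ObjectIdentitySet:
  membership is keyed on id(), so items' __eq__/__hash__ are never invoked."""

  def __init__(self, items=()):
    self._storage = {id(x): x for x in items}

  def __contains__(self, item):
    return id(item) in self._storage

def inputs_before(op_inputs, xs):
  # Pre-fix shape of _Inputs: a fresh identity set is built on every call.
  xs_set = ObjectIdentitySet(xs)
  return [t for t in op_inputs if t not in xs_set]

def inputs_after(op_inputs, xs_set):
  # Post-fix shape: the caller builds xs_set once and passes it down.
  return [t for t in op_inputs if t not in xs_set]

xs = [object() for _ in range(1000)]      # stands in for the xs tensors
op_inputs = [object() for _ in range(5)]  # stands in for op.inputs
print(timeit.timeit(lambda: inputs_before(op_inputs, xs), number=1000))
hoisted = ObjectIdentitySet(xs)           # built once per gradients() call
print(timeit.timeit(lambda: inputs_after(op_inputs, hoisted), number=1000))

With thousands of visited ops, the repeated construction dominated the traversal; hoisting it restores O(1) membership checks per input.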
@@ -69,7 +69,7 @@ def _MarkReachedOps(from_ops, reached_ops, func_graphs):
 def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
-                  xs):
+                  xs_set):
   """Initialize the pending count for ops between two lists of Operations.
 
   'pending_count[op]' indicates the number of backprop inputs
@@ -83,7 +83,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       these functions if they capture from_ops or any reachable ops. This is
       useful if to_ops occur in a function and from_ops are in an outer function
       or graph.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     A tuple containing: (1) the subset of to_ops reachable from from_ops by a
@@ -113,7 +113,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
       between_op_list.append(op)
       # Clear the boolean so we won't add the inputs again.
       reached_ops.remove(op)
-      for inp in _NonEagerInputs(op, xs):
+      for inp in _NonEagerInputs(op, xs_set):
         queue.append(inp.op)
   # X in between_ops iff X is on a path of zero or more backpropagatable tensors
   # between from_ops and to_ops
@@ -125,7 +125,7 @@ def _PendingCount(to_ops, from_ops, colocate_gradients_with_ops, func_graphs,
   # Initialize pending count for between ops.
   pending_count = collections.defaultdict(int)
   for op in between_op_list:
-    for x in _NonEagerInputs(op, xs):
+    for x in _NonEagerInputs(op, xs_set):
       if x.op in between_ops:
         pending_count[x.op] += 1
@@ -265,7 +265,7 @@ def _VerifyGeneratedGradients(grads, op):
                      "inputs %d" % (len(grads), op.node_def, len(op.inputs)))
 
 
-def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
+def _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set):
   """The set of ops that terminate the gradient computation.
 
   This computes the frontier of the forward graph *before* which backprop
@@ -281,7 +281,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
     from_ops: list of Operations.
     stop_gradient_ops: list of Operations never to backprop through.
     pending_count: mapping from operation to number of backprop inputs.
-    xs: list of Tensors.
+    xs_set: ObjectIdentitySet of Tensors.
 
   Returns:
     The set of operations.
@@ -289,7 +289,7 @@ def _StopOps(from_ops, stop_gradient_ops, pending_count, xs):
   stop_ops = set()
   for op in from_ops:
     is_stop_op = True
-    for inp in _NonEagerInputs(op, xs):
+    for inp in _NonEagerInputs(op, xs_set):
       if pending_count[inp.op] > 0:
         is_stop_op = False
         break
@@ -369,7 +369,7 @@ def _MaybeCompile(scope, op, func, grad_fn):
     return grad_fn()
 
 
-def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
+def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set):
   """Raises an error if we backprop through a loop var."""
   # Find the nearest 'to_op' reachable from 'op' to provide a more helpful error
   # message.
@@ -383,7 +383,7 @@ def _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs):
     if curr_op in from_ops:
       target_op = curr_op
       break
-    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs))
+    queue.extend(t.op for t in _NonEagerInputs(curr_op, xs_set))
   assert target_op
   raise ValueError(
       "Cannot compute gradient inside while loop with respect to op '%s'. "
@@ -425,7 +425,7 @@ def _MaybeCaptured(t):
   return t
 
 
-def _NonEagerInputs(op, xs):
+def _NonEagerInputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Does not return any captured EagerTensors, i.e., the number of tensors
@@ -433,29 +433,28 @@ def _NonEagerInputs(op, xs):
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  return [t for t in _Inputs(op, xs) if not isinstance(t, ops.EagerTensor)]
+  return [t for t in _Inputs(op, xs_set) if not isinstance(t, ops.EagerTensor)]
 
 
 # TODO(skyewm): plumbing xs through everywhere is ugly, consider making
 # _GradientsHelper a class with xs as a member variable.
-def _Inputs(op, xs):
+def _Inputs(op, xs_set):
   """Returns the inputs of op, crossing closure boundaries where necessary.
 
   Args:
     op: Operation
-    xs: list of Tensors we are differentiating w.r.t.
+    xs_set: ObjectIdentitySet of Tensors we are differentiating w.r.t.
 
   Returns:
     A list of tensors. The tensors may be from multiple Graph/FuncGraphs if op
     is in a FuncGraph and has captured inputs.
   """
-  tensors = object_identity.ObjectIdentitySet(xs)
   if _IsFunction(op.graph):  # pylint: disable=protected-access
     inputs = []
     for t in op.inputs:
@@ -464,7 +463,7 @@ def _Inputs(op, xs):
       # even if it's a function input for a captured value, whereas usually we'd
       # like to traverse through these closures as if the captured value was the
       # direct input to op.
-      if t not in tensors:
+      if t not in xs_set:
         t = _MaybeCaptured(t)
       inputs.append(t)
     return inputs
@@ -546,6 +545,7 @@ def _GradientsHelper(ys,
   ]
   xs = ops.internal_convert_n_to_tensor_or_indexed_slices(
       xs, name="x", as_ref=True)
+  xs_set = object_identity.ObjectIdentitySet(xs)
   grad_ys = _DefaultGradYs(grad_ys, ys, colocate_gradients_with_ops,
                            gradient_uid)
@@ -562,7 +562,7 @@ def _GradientsHelper(ys,
     from_ops = [t.op for t in xs]
     stop_gradient_ops = [t.op for t in stop_gradients]
     reachable_to_ops, pending_count, loop_state = _PendingCount(
-        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs)
+        to_ops, from_ops, colocate_gradients_with_ops, func_graphs, xs_set)
 
     # Iterate over the collected ops.
     #
@@ -596,7 +596,7 @@ def _GradientsHelper(ys,
           _SetGrad(grads, y, loop_state.ZerosLikeForExit(y))
         queue.append(y.op)
 
-    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs)
+    stop_ops = _StopOps(from_ops, stop_gradient_ops, pending_count, xs_set)
     while queue:
       # generate gradient subgraph for op.
       op = queue.popleft()
@@ -649,7 +649,7 @@ def _GradientsHelper(ys,
                 op._control_flow_context.IsWhileContext() and
                 op._control_flow_context ==
                 ops.get_default_graph()._get_control_flow_context()):
-              _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs)
+              _RaiseNoGradWrtInitialLoopValError(op, from_ops, xs_set)
             # pylint: enable=protected-access
 
           if (grad_fn or is_func_call) and has_out_grads:
@@ -696,10 +696,10 @@ def _GradientsHelper(ys,
         else:
           # If no grad_fn is defined or none of out_grads is available,
           # just propagate a list of None backwards.
-          in_grads = [None] * len(_Inputs(op, xs))
+          in_grads = [None] * len(_Inputs(op, xs_set))
         # Note: we don't filter out eager inputs here because the inputs need to
         # line up with in_grads.
-        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs), in_grads)):
+        for i, (t_in, in_grad) in enumerate(zip(_Inputs(op, xs_set), in_grads)):
           if in_grad is not None:
             if (isinstance(in_grad, ops.Tensor) and
                 t_in.dtype != dtypes.resource):
@@ -719,7 +719,7 @@ def _GradientsHelper(ys,
       # Update pending count for the inputs of op and enqueue ready ops.
       _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                    xs)
+                                    xs_set)
 
   if loop_state:
     loop_state.PostProcessing()
@@ -739,9 +739,9 @@ def _HasAnyNotNoneGrads(grads, op):
 
 
 def _UpdatePendingAndEnqueueReady(grads, op, queue, pending_count, loop_state,
-                                  xs):
+                                  xs_set):
   """Update pending count for the inputs of op and enqueue ready ops."""
-  for x in _NonEagerInputs(op, xs):
+  for x in _NonEagerInputs(op, xs_set):
     pending_count[x.op] -= 1
     ready = (pending_count[x.op] == 0)
     if loop_state and not ready:
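For readers following the touched helpers: `_PendingCount` computes, for each op between `xs` and the `ys`, how many backprop gradients must arrive before that op's own gradient can be generated, and `_UpdatePendingAndEnqueueReady` decrements those counts and enqueues ops as they become ready. A loose, self-contained sketch of that queue discipline over a toy dict graph (hypothetical names, not the TensorFlow implementation):

import collections

def pending_count(between_ops, inputs_of):
  # pending_count[op] = number of ops in the subgraph consuming op's outputs,
  # i.e. how many gradients must arrive before op can be processed.
  counts = collections.defaultdict(int)
  for op in between_ops:
    for inp in inputs_of(op):
      if inp in between_ops:
        counts[inp] += 1
  return counts

def backprop_order(ys_ops, between_ops, inputs_of):
  # Mirrors the shape of the _GradientsHelper queue loop and
  # _UpdatePendingAndEnqueueReady: an op is enqueued once the gradients from
  # all of its consumers are ready (its pending count hits zero).
  counts = pending_count(between_ops, inputs_of)
  order, queue = [], collections.deque(ys_ops)
  while queue:
    op = queue.popleft()
    order.append(op)
    for inp in inputs_of(op):
      if inp in between_ops:
        counts[inp] -= 1
        if counts[inp] == 0:
          queue.append(inp)
  return order

# Toy graph: op name -> its inputs. x feeds both a and y, so x waits for two
# gradients before it is dequeued.
graph = {"y": ["a", "x"], "a": ["x"], "x": []}
print(backprop_order(["y"], {"y", "a", "x"}, graph.get))  # ['y', 'a', 'x']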