Commit e42b937e authored by Dan Moldovan, committed by TensorFlower Gardener

Enable support for loop directives, including shape_invariants.

PiperOrigin-RevId: 285966483
Change-Id: I3eae0b134cd2e954bfa0ac31e6a7411b3a5bb7df
Parent 1768c8f2
@@ -21,6 +21,7 @@ from __future__ import print_function
import gast
from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.lang import directives
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import parser
@@ -151,6 +152,20 @@ class ControlFlowTransformer(converter.Base):
return node
+  def _create_loop_options(self, node):
+    if not anno.hasanno(node, anno.Basic.DIRECTIVES):
+      return gast.Dict([], [])
+    loop_directives = anno.getanno(node, anno.Basic.DIRECTIVES)
+    if directives.set_loop_options not in loop_directives:
+      return gast.Dict([], [])
+    opts_dict = loop_directives[directives.set_loop_options]
+    str_keys, values = zip(*opts_dict.items())
+    keys = [gast.Str(s) for s in str_keys]
+    values = list(values)  # ast and gast don't play well with tuples.
+    return gast.Dict(keys, values)
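
The new method above lowers the options captured by the directives analysis into a dict literal that the generated loop code receives. A minimal sketch of that lowering, assuming gast 0.2.x (the version this file is written against) and a hypothetical directive payload:

import gast

# Hypothetical payload: option names mapped to their value AST nodes.
opts_dict = {'parallel_iterations': gast.Num(n=4)}
str_keys, values = zip(*opts_dict.items())
keys = [gast.Str(s) for s in str_keys]  # option names become string literals
opts_node = gast.Dict(keys, list(values))
print(opts_node.keys[0].s, opts_node.values[0].n)  # parallel_iterations 4
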
def _create_undefined_assigns(self, undefined_symbols):
assignments = []
for s in undefined_symbols:
@@ -383,8 +398,7 @@ class ControlFlowTransformer(converter.Base):
composite_symbol_names = tuple(
gast.Str(str(symbol)) for symbol in composite_loop_vars)
-    # TODO(b/140125096): Populate.
-    opts = gast.Dict([], [])
+    opts = self._create_loop_options(node)
# TODO(mdan): Use a single template.
# If the body and test functions took a single tuple for loop_vars, instead
@@ -507,8 +521,7 @@ class ControlFlowTransformer(converter.Base):
composite_symbol_names = tuple(
gast.Str(str(symbol)) for symbol in composite_loop_vars)
-    # TODO(b/140125096): Populate.
-    opts = gast.Dict([], [])
+    opts = self._create_loop_options(node)
# TODO(mdan): Use a single template.
# If the body and test functions took a single tuple for loop_vars, instead
@@ -26,7 +26,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.util.tf_export import tf_export
-from tensorflow.tools.docs.doc_controls import do_not_generate_docs
UNSPECIFIED = object()
@@ -47,37 +46,53 @@ def set_element_type(entity, dtype, shape=UNSPECIFIED):
del shape
-# TODO(b/140125096): Implement.
-@do_not_generate_docs
+@tf_export('autograph.experimental.set_loop_options')
def set_loop_options(
parallel_iterations=UNSPECIFIED,
-    back_prop=UNSPECIFIED,
swap_memory=UNSPECIFIED,
-    maximum_iterations=UNSPECIFIED):
+    maximum_iterations=UNSPECIFIED,
+    shape_invariants=UNSPECIFIED):
"""Specifies additional arguments to be passed to the enclosing while_loop.
  The parameters apply to, and only to, the immediately enclosing loop. They
  only take effect if the loop is staged as a TF while_loop; otherwise the
  parameters have no effect.
-  Usage example:
+  Usage:
-    @tf.function(autograph=True)
-    def dynamic_rnn(..., parallel_iterations=32):
-      num_steps = ...
-      for t in tf.range(num_steps):
-        tf.autograph.experimental.set_loop_options(
-            parallel_iterations=parallel_iterations)
-        ...
+  >>> @tf.function(autograph=True)
+  ... def f():
+  ...   n = 0
+  ...   for i in tf.range(10):
+  ...     tf.autograph.experimental.set_loop_options(maximum_iterations=3)
+  ...     n += 1
+  ...   return n
+  >>> @tf.function(autograph=True)
+  ... def f():
+  ...   v = tf.constant((0,))
+  ...   for i in tf.range(3):
+  ...     tf.autograph.experimental.set_loop_options(
+  ...       shape_invariants=[(v, tf.TensorShape([None]))]
+  ...     )
+  ...     v = tf.concat((v, [i]), 0)
+  ...   return v
Also see tf.while_loop.
Args:
-    parallel_iterations: See tf.while_loop.
-    back_prop: See tf.while_loop.
-    swap_memory: See tf.while_loop.
-    maximum_iterations: See tf.while_loop.
+    parallel_iterations: The maximum number of iterations allowed to run in
+      parallel at any given time. Note that this does not guarantee parallel
+      execution.
+    swap_memory: Whether to store intermediate values needed for
+      gradients on the CPU instead of GPU.
+    maximum_iterations: Allows limiting the total number of iterations executed
+      by the loop.
+    shape_invariants: Allows controlling the argument with the same name passed
+      to tf.while_loop. Unlike tf.while_loop, this is a list of
+      `(tensor, shape)` pairs.
"""
del parallel_iterations
-  del back_prop
del swap_memory
del maximum_iterations
+  del shape_invariants
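
The del statements make the call a no-op when the function runs unconverted; under AutoGraph the converter above intercepts the directive, so it never executes. For reference, a sketch of roughly what the shape_invariants example stages, written directly against tf.while_loop:

import tensorflow as tf

v0 = tf.constant((0,))
i0 = tf.constant(0)
# Grow v by one element per iteration; the invariant relaxes its first dimension.
v, i = tf.while_loop(
    cond=lambda v, i: i < 3,
    body=lambda v, i: (tf.concat((v, [i]), 0), i + 1),
    loop_vars=(v0, i0),
    shape_invariants=(tf.TensorShape([None]), i0.shape))
print(v.numpy())  # [0 0 1 2]
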
@@ -125,68 +125,91 @@ def _is_subshape(left, right):
return True
-def _verify_single_loop_var(name, check_shape, init_loop_var, first_iter_var):
-  """Verifies whether init_loop_var and first_iter_var are consistent."""
-  if isinstance(init_loop_var, (bool, int, float, str)):
-    init_loop_var = ops.convert_to_tensor_v2(init_loop_var)
-  if isinstance(first_iter_var, (bool, int, float, str)):
-    first_iter_var = ops.convert_to_tensor_v2(first_iter_var)
-  if (not tensor_util.is_tensor(init_loop_var) or
-      not tensor_util.is_tensor(first_iter_var)):
+# TODO(mdan): Remove these verifications once TF ops can properly report names.
+def _verify_single_loop_var(
+    name, check_shape, init, entry, exit_, shape_invariant):
+  """Verifies whether the initial, entry and exit values are consistent."""
+  if isinstance(init, (bool, int, float, str, np.ndarray)):
+    init = ops.convert_to_tensor_v2(init)
+  if isinstance(entry, (bool, int, float, str, np.ndarray)):
+    entry = ops.convert_to_tensor_v2(entry)
+  if isinstance(exit_, (bool, int, float, str)):
+    exit_ = ops.convert_to_tensor_v2(exit_)
+  if (not tensor_util.is_tensor(entry) or
+      not tensor_util.is_tensor(exit_)):
return
# TODO(mdan): Properly account for CompositeTensors.
-  if (not hasattr(init_loop_var, 'dtype') or
-      not hasattr(first_iter_var, 'dtype')):
+  if (not hasattr(entry, 'dtype') or
+      not hasattr(exit_, 'dtype')):
    return
-  if (not hasattr(init_loop_var, 'shape') or
-      not hasattr(first_iter_var, 'shape')):
+  if (not hasattr(entry, 'shape') or
+      not hasattr(exit_, 'shape')):
    return
-  if init_loop_var.dtype != first_iter_var.dtype:
+  if entry.dtype != exit_.dtype:
raise TypeError(
'"{}" has dtype {} before the loop, but dtype {} after one'
' iteration. TensorFlow control flow requires it stays the'
' same.'.format(
name,
-            init_loop_var.dtype.name,
-            first_iter_var.dtype.name,
+            entry.dtype.name,
+            exit_.dtype.name,
))
if check_shape:
-    init_shape = init_loop_var.shape
-    first_iter_shape = first_iter_var.shape
-    # TODO(b/135183013): Update needed once we support shape_invariants.
-    if not _is_subshape(first_iter_shape, init_shape):
-      raise ValueError(
-          '"{}" has shape {} before the loop, but shape {} after one'
-          ' iteration. TensorFlow control flow requires it stays the'
-          ' same or be more specific.'.format(name, init_shape,
-                                              first_iter_shape))
+    exit_shape = exit_.shape
+    if shape_invariant is None:
+      entry_shape = entry.shape
+      if not _is_subshape(exit_shape, entry_shape):
+        raise ValueError(
+            '"{}" has shape {} before the loop, but shape {} after one'
+            ' iteration. Use tf.autograph.experimental.set_loop_options to set'
+            ' shape invariants.'.format(name, entry_shape, exit_shape))
+    else:
+      init_shape = init.shape
+      if not _is_subshape(init_shape, shape_invariant):
+        raise ValueError(
+            '"{}" has shape {} before the loop, which does not conform with'
+            ' the shape invariant {}.'.format(name, init_shape,
+                                              shape_invariant))
+      if not _is_subshape(exit_shape, shape_invariant):
+        raise ValueError(
+            '"{}" has shape {} after the loop, which does not conform with'
+            ' the shape invariant {}.'.format(
+                name, exit_shape, shape_invariant))
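
A sketch of the first error path, assuming a TF build that includes this change: a loop variable whose shape changes across iterations is rejected at trace time unless an invariant is declared, and the message now points at set_loop_options:

import tensorflow as tf

@tf.function
def grow():
  v = tf.constant((0,))
  for i in tf.range(3):
    v = tf.concat((v, [i]), 0)  # shape (1,) -> (2,) after one iteration
  return v

try:
  grow()
except ValueError as e:
  print(e)  # '"v" has shape (1,) before the loop, but shape (2,) after one
            #  iteration. Use tf.autograph.experimental.set_loop_options ...'
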
def _verify_tf_loop_vars(init_loop_vars,
first_iter_vars,
def _verify_tf_loop_vars(init_vars,
iter_entry_vars,
iter_exit_vars,
symbol_names,
opts,
check_shapes=True):
"""Verifies loop variables for consistency."""
-  # TODO(b/140125096): Use this.
-  del opts
+  if check_shapes and 'shape_invariants' in opts:
+    shape_invariants = opts['shape_invariants']
+  else:
+    shape_invariants = nest.map_structure(lambda _: None, iter_entry_vars)
-  named_vars = zip(symbol_names, init_loop_vars, first_iter_vars)
-  for name, init_loop_var, first_iter_var in named_vars:
+  named_vars = zip(symbol_names, init_vars, iter_entry_vars, iter_exit_vars,
+                   shape_invariants)
+  for name, init, entry, exit_, invariant in named_vars:
try:
-      nest.assert_same_structure(
-          init_loop_var, first_iter_var, expand_composites=True)
+      nest.assert_same_structure(entry, exit_, expand_composites=True)
except (ValueError, TypeError) as e:
raise TypeError('"{}" does not have the same nested structure after one'
' iteration.\n\n{}'.format(name, e))
+    if invariant is not None:
+      try:
+        nest.assert_same_structure(init, invariant, expand_composites=False)
+      except (ValueError, TypeError) as e:
+        raise TypeError('"{}" does not have the same nested structure as its'
+                        ' corresponding shape invariant.\n\n{}'.format(name, e))
nest.map_structure(
-        functools.partial(_verify_single_loop_var, name, check_shapes),
-        init_loop_var, first_iter_var)
+        functools.partial(_verify_single_loop_var, name, check_shapes), init,
+        entry, exit_, invariant)
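
When no shape_invariants option is supplied, every entry variable is paired with a None invariant of matching structure, which keeps the per-variable loop above uniform. A minimal illustration of that default using the public tf.nest API:

import tensorflow as tf

iter_entry_vars = ((tf.constant(0), tf.constant(1.0)), tf.constant(2))
invariants = tf.nest.map_structure(lambda _: None, iter_entry_vars)
print(invariants)  # ((None, None), None)
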
def _verify_single_cond_var(name, body_var, orelse_var):
@@ -425,6 +448,8 @@ def _tf_ragged_for_stmt(iter_,
else:
n = iter_.row_lengths()[0]
+  opts['maximum_iterations'] = n
def while_body(iterate_index, *loop_vars):
"""Main loop body."""
iterate = iter_[iterate_index]
@@ -566,7 +591,7 @@ def _tf_iterator_for_stmt(itr, extra_test, body, get_state, set_state,
# Note: this verification duplicates that performed in tf_while_stmt,
# but needs to be done earlier to prevent the tf.cond inside while_body
# from blowing up first.
-    _verify_tf_loop_vars(loop_vars, new_vars,
+    _verify_tf_loop_vars(init_vars, loop_vars, new_vars,
basic_symbol_names + composite_symbol_names, opts)
return new_vars
@@ -653,20 +678,26 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
# TODO(mdan): Simplify this - following it is extremely difficult.
init_state = get_state()
+  aug_init_vars = init_vars, init_state
def scan_body(aug_vars, iterate):
"""The main loop body wrapper. Only calculates the stop condition."""
loop_vars, state = aug_vars
def true_fn():
"""Main path - stop condition is not set."""
set_state(state)
-      outputs = body(iterate, *loop_vars)
+      new_vars = body(iterate, *loop_vars)
+      new_state = get_state()
_verify_tf_loop_vars(
+          init_vars + init_state,
loop_vars + state,
-          outputs + state,
+          new_vars + new_state,
basic_symbol_names + composite_symbol_names,
opts,
check_shapes=False)
-      return outputs, get_state()
+      return new_vars, new_state
extra_cond = extra_test(*loop_vars)
new_vars, new_state = control_flow_ops.cond(
@@ -690,11 +721,9 @@ def _dataset_for_stmt_with_extra_test(ds, extra_test, body, get_state,
del extra_cond
return output_aug_vars, output_state
-  init_state = get_state()
-  aug_vars = init_vars, init_state
-  ds = _general_purpose_scan(ds, aug_vars, scan_body)
+  ds = _general_purpose_scan(ds, aug_init_vars, scan_body)
ds = ds.apply(take_while_ops.take_while(take_while_predicate))
-  final_aug_vars = ds.reduce(aug_vars, reduce_body)
+  final_aug_vars = ds.reduce(aug_init_vars, reduce_body)
final_vars, final_state = final_aug_vars
set_state(final_state)
return final_vars
@@ -741,6 +770,7 @@ def _dataset_for_stmt_no_extra_test(ds, body, get_state, set_state, init_vars,
new_state = get_state()
_verify_tf_loop_vars(
+        init_vars + init_state,
loop_vars + state,
new_vars + new_state,
symbol_names,
@@ -824,11 +854,23 @@ def while_stmt(test,
return _py_while_stmt(test, body, get_state, set_state, init_vars, opts)
+def _shape_invariants_mapping_to_positional_list(mapping, keys):
+  # The keys are not expected to be hashable.
+  mapping = {id(k): (k, v) for k, v in mapping}
+  result = []
+  for k in keys:
+    map_key, map_val = mapping.get(id(k), (None, None))
+    result.append(map_val if map_key is k else None)
+  return tuple(result)
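
Tensors are not hashable, so the helper keys the user's (tensor, shape) pairs by id() and re-checks identity before trusting a hit. A pure-Python sketch of the same logic with stand-in objects (the names here are hypothetical):

class FakeTensor(object):  # stand-in for an unhashable Tensor
  __hash__ = None

a, b = FakeTensor(), FakeTensor()
mapping = [(a, '[None]')]  # user-supplied (tensor, shape) pairs
keys = (a, b)              # positional loop variables

by_id = {id(k): (k, v) for k, v in mapping}
result = []
for k in keys:
  map_key, map_val = by_id.get(id(k), (None, None))
  result.append(map_val if map_key is k else None)  # identity check guards id() reuse
print(tuple(result))  # ('[None]', None)
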
def _tf_while_stmt(test, body, get_state, set_state, init_vars,
basic_symbol_names, composite_symbol_names, opts):
"""Overload of while_stmt that stages a TF while_stmt."""
_disallow_undefs_into_loop(*init_vars)
+  aug_init_vars = init_vars + get_state()
# TODO(mdan): Simplify this.
loop_vars_slice = slice(len(init_vars))
state_slice = slice(len(init_vars), None)
@@ -844,7 +886,7 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
set_state(state)
loop_vars = body(*aug_loop_vars[loop_vars_slice])
new_state = loop_vars + get_state()
-    _verify_tf_loop_vars(aug_loop_vars, new_state,
+    _verify_tf_loop_vars(aug_init_vars, aug_loop_vars, new_state,
basic_symbol_names + composite_symbol_names, opts)
return new_state
@@ -853,7 +895,10 @@ def _tf_while_stmt(test, body, get_state, set_state, init_vars,
# This enforces consistency across versions.
opts['return_same_structure'] = True
-  aug_init_vars = init_vars + get_state()
+  if 'shape_invariants' in opts:
+    opts['shape_invariants'] = _shape_invariants_mapping_to_positional_list(
+        opts['shape_invariants'], aug_init_vars)
final_aug_vars = control_flow_ops.while_loop(aug_test, aug_body,
aug_init_vars, **opts)
final_state = final_aug_vars[state_slice]
@@ -503,13 +503,16 @@ def _shape_invariant_to_type_spec(var, shape):
Returns:
A `TypeSpec` for `var`, consistent with the given shape.
"""
-  if isinstance(shape, type_spec.TypeSpec):
+  if shape is None:
+    return type_spec.type_spec_from_value(var)
+  elif isinstance(shape, type_spec.TypeSpec):
if not shape.is_compatible_with(var):
raise TypeError("TypeSpec %r is not compatible with %r" % (shape, var))
return shape
elif not isinstance(shape, tensor_shape.TensorShape):
raise TypeError("Expected shape to be a TypeSpec or TensorShape, got %r"
% shape)
raise TypeError(
"Expected shape to be a TypeSpec, TensorShape or None, got %r for"
" value %r" % (shape, var))
if isinstance(var, ops.Tensor):
return tensor_spec.TensorSpec(shape, var.dtype)
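
With this change a None invariant falls back to the value's own TypeSpec. A sketch of the two accepted forms for a plain Tensor, using only public APIs:

import tensorflow as tf

v = tf.zeros([2, 3])
# shape is None: infer the spec from the value itself.
print(tf.TensorSpec.from_tensor(v))  # TensorSpec(shape=(2, 3), dtype=tf.float32, ...)
# shape is a TensorShape: build a spec with the relaxed shape.
print(tf.TensorSpec(tf.TensorShape([None, 3]), v.dtype))
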
@@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "set_loop_options"
argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
}
}
@@ -10,6 +10,6 @@ tf_module {
}
member_method {
name: "set_loop_options"
argspec: "args=[\'parallel_iterations\', \'back_prop\', \'swap_memory\', \'maximum_iterations\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
argspec: "args=[\'parallel_iterations\', \'swap_memory\', \'maximum_iterations\', \'shape_invariants\'], varargs=None, keywords=None, defaults=[\'<object object instance>\', \'<object object instance>\', \'<object object instance>\', \'<object object instance>\'], "
}
}