提交 51f8f449 编写于 作者: V Vojtech Bardiovsky 提交者: TensorFlower Gardener

Fix bug in partial & tf.function.

This was caused by wrong alignment of arguments and default values.

Also add one more assert to improve debuggability.

PiperOrigin-RevId: 251580104
上级 868f22ad
......@@ -265,6 +265,30 @@ class DefFunctionTest(test.TestCase):
functools.partial(f, constant_op.constant(1)))
self.assertAllEqual(func(5), 6)
def test_complicated_partial_with_defaults(self):
def identity(*args):
return args
def dynamic_unroll(core_fn,
input_sequence,
initial_state,
sequence_length=None,
parallel_iterations=1,
swap_memory=False):
del core_fn
self.assertIs(None, sequence_length)
self.assertEqual(1, parallel_iterations)
self.assertTrue(swap_memory)
return input_sequence, initial_state
input_sequence = random_ops.random_uniform([1, 1, 1])
initial_state = random_ops.random_uniform([1, 1])
func = def_function.function(
functools.partial(dynamic_unroll, identity, swap_memory=True))
func(input_sequence, initial_state)
def test_unspecified_default_argument(self):
wrapped = def_function.function(
lambda x, y=2: x + y,
......
......@@ -1004,18 +1004,49 @@ class FunctionSpec(object):
new_defaults = fullargspec.defaults
new_args = fullargspec.args
if fullargspec.defaults:
num_defaults = len(fullargspec.defaults)
args_with_default = fullargspec.args[-num_defaults:]
# To be able to canonicalize the function properly, we want to ignore
# default values that are overridden via a partial kwarg. For example:
#
# def func(a, b, c, d=5, e=7):
# return a, b, c, d, e
# p_func = functools.partial(tf.function(func, 10, e=9))
#
# Here we want to drop from the defaults the parameter `e`. If we
# forwarded the call to the partial function with a default for `e`
# we would get an error for passing two values for one parameter.
#
# Note that this has a limitation: we can only override parameters at
# the end of the parameter list.
#
# In this case we want to end up with 3 arguments (b, c, d) and 1
# default value (5). We do this by constructing a mask where 0 stands
# for a value that was overridden by a partial kwarg. The seemingly
# complicated logic below does just that - for arguments (b, c, d, e)
# we would get a mask (1, 1, 1, 0).
old_args = fullargspec.args
old_defaults = fullargspec.defaults
no_default = object()
num_args_without_defaults = len(old_args) - len(old_defaults)
left_padding = tuple([no_default] * num_args_without_defaults)
args_with_defaults = zip(old_args, left_padding + old_defaults)
# Create a mask where 0 stands for args that had a partial kwarg
# defined.
non_keyword_defaults_mask = [
0 if key in unwrapped.keywords else 1 for key in args_with_default
0 if key in unwrapped.keywords else 1 for key in old_args
]
# Keep only arguments and defaults that were not kwargs of partial.
new_defaults = tuple(
itertools.compress(fullargspec.defaults,
non_keyword_defaults_mask))
new_args = list(
itertools.compress(fullargspec.args, non_keyword_defaults_mask))
new_args_with_defaults = list(
itertools.compress(args_with_defaults, non_keyword_defaults_mask))
# Keep all args.
new_args = [arg for arg, _ in new_args_with_defaults]
# Keep only real default values.
new_defaults = [
default for _, default in new_args_with_defaults
if default is not no_default
]
fullargspec = tf_inspect.FullArgSpec(
args=new_args,
varargs=fullargspec.varargs,
......@@ -1136,7 +1167,12 @@ class FunctionSpec(object):
if not kwargs:
inputs = args
for index in sorted(self._arg_indices_to_default_values.keys()):
default_keys = sorted(self._arg_indices_to_default_values.keys())
if default_keys:
assert min(default_keys) <= len(
args), "Not enough arguments (%s, %s, %s)" % (args, default_keys,
self.arg_names)
for index in default_keys:
if index >= len(args):
inputs += (self._arg_indices_to_default_values[index],)
else:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册