From 5fe44571f0cd80949bf59287b3008e044d711675 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 27 Nov 2020 14:11:57 +0800 Subject: [PATCH] [Dynamic-to-Static] Support **kwargs as input of the function which is decorated by `jit.save.to_static` (#29098) --- .../dygraph_to_static/function_spec.py | 47 ++++++++++-------- .../dygraph_to_static/program_translator.py | 48 +++++++++++++------ .../dygraph_to_static/test_declarative.py | 5 +- .../unittests/dygraph_to_static/test_dict.py | 38 +++++++++++++++ .../dygraph_to_static/test_function_spec.py | 8 +++- 5 files changed, 107 insertions(+), 39 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py index 3d1ed836ff1..34fb168495a 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/function_spec.py @@ -50,10 +50,10 @@ class FunctionSpec(object): """ Moves kwargs with default value into arguments list to keep `args` contain the same length value as function definition. - - For example: - - Given function definition: `def foo(x, a=1, b=2)`, + + For example: + + Given function definition: `def foo(x, a=1, b=2)`, when calling it by `foo(23)`, the args is `[23]`, kwargs is `{a=1, b=2}`. In this function, it will return args with `[23, 1, 2]`, kwargs with `{}` @@ -91,10 +91,23 @@ class FunctionSpec(object): return tuple(args), kwargs + def _replace_value_with_input_spec(self, args): + args_with_spec = [] + for idx, input_var in enumerate(flatten(args)): + if isinstance(input_var, np.ndarray): + input_var = paddle.static.InputSpec.from_numpy(input_var) + elif isinstance(input_var, core.VarBase): + input_var = paddle.static.InputSpec.from_tensor(input_var) + + args_with_spec.append(input_var) + + args_with_spec = pack_sequence_as(args, args_with_spec) + return args_with_spec + def args_to_input_spec(self, args, kwargs): """ Converts input arguments into InputSpec. - + 1. If specific input_spec, use them to construct feed layers. 2. If input_spec is None, consider all Tensor and Numpy.ndarray as feed layers @@ -103,10 +116,11 @@ class FunctionSpec(object): kwargs(dict): kwargs arguments received by **kwargs. Return: - Same nest structure with args by replacing value with InputSpec. + Same nest structure with args and kwargs by replacing value with InputSpec. """ - input_with_spec = [] + args_with_spec = [] + kwargs_with_spec = [] if self._input_spec is not None: # Note: Because the value type and length of `kwargs` is uncertain. # So we don't support to deal this case while specificing `input_spec` currently. @@ -124,24 +138,17 @@ class FunctionSpec(object): format(len(args), len(self._input_spec))) # replace argument with corresponding InputSpec. - input_with_spec = convert_to_input_spec(args, self._input_spec) + args_with_spec = convert_to_input_spec(args, self._input_spec) else: - for idx, input_var in enumerate(flatten(args)): - if isinstance(input_var, np.ndarray): - input_var = paddle.static.InputSpec.from_numpy(input_var) - elif isinstance(input_var, core.VarBase): - input_var = paddle.static.InputSpec.from_tensor(input_var) - - input_with_spec.append(input_var) - - input_with_spec = pack_sequence_as(args, input_with_spec) + args_with_spec = self._replace_value_with_input_spec(args) + kwargs_with_spec = self._replace_value_with_input_spec(kwargs) # If without specificing name in input_spec, add default name # according to argument name from decorated function. - input_with_spec = replace_spec_empty_name(self._arg_names, - input_with_spec) + args_with_spec = replace_spec_empty_name(self._arg_names, + args_with_spec) - return input_with_spec + return args_with_spec, kwargs_with_spec @switch_to_static_graph def to_static_inputs_with_spec(self, input_with_spec, main_program): diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py index 31ca24e3c12..581eec5cfd3 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py @@ -146,19 +146,25 @@ class CacheKey(object): Cached key for ProgramCache. """ - __slots__ = ['function_spec', 'input_with_spec', 'class_instance'] + __slots__ = [ + 'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec', + 'class_instance' + ] - def __init__(self, function_spec, input_with_spec, class_instance): + def __init__(self, function_spec, input_args_with_spec, + input_kwargs_with_spec, class_instance): """ Initializes a cache key. Args: functions_spec(FunctionSpec): a FunctionSpec instance of decorated function. - input_with_spec(list[InputSpec]): actual inputs with some arguments replaced by InputSpec. + input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec. + input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec. class_instance(object): a instance of class `Layer`. """ self.function_spec = function_spec - self.input_with_spec = input_with_spec + self.input_args_with_spec = input_args_with_spec + self.input_kwargs_with_spec = input_kwargs_with_spec self.class_instance = class_instance @classmethod @@ -177,15 +183,18 @@ class CacheKey(object): args = args[1:] # 2. convert tensor and numpy array into InputSpec _args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs) - input_with_spec = function_spec.args_to_input_spec(_args, _kwargs) + input_args_with_spec, input_kwargs_with_spec = function_spec.args_to_input_spec( + _args, _kwargs) # 3. check whether hit the cache or build a new program for the input arguments - return CacheKey(function_spec, input_with_spec, class_instance) + return CacheKey(function_spec, input_args_with_spec, + input_kwargs_with_spec, class_instance) def __hash__(self): error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)." return hash((id(self.function_spec), - make_hashable(self.input_with_spec, error_msg), + make_hashable(self.input_args_with_spec, error_msg), + make_hashable(self.input_kwargs_with_spec, error_msg), self.class_instance)) def __eq__(self, other): @@ -195,8 +204,9 @@ class CacheKey(object): return not self == other def __repr__(self): - return "id(function_spec): {}, input_with_spec: {}, class_instance: {}".format( - id(self.function_spec), self.input_with_spec, self.class_instance) + return "id(function_spec): {}, input_args_with_spec: {}, input_kwargs_with_spec: {}, class_instance: {}".format( + id(self.function_spec), self.input_args_with_spec, + self.input_kwargs_with_spec, self.class_instance) def unwrap_decorators(func): @@ -380,11 +390,12 @@ class StaticFunction(object): if len(args) != len(self._function_spec.args_name): args, kwargs = self._function_spec.unified_args_and_kwargs(args, kwargs) - input_with_spec = self._function_spec.args_to_input_spec(args, kwargs) + input_args_with_spec, input_kwargs_with_spec = self._function_spec.args_to_input_spec( + args, kwargs) # 2. generate cache key - cache_key = CacheKey(self._function_spec, input_with_spec, - self._class_instance) + cache_key = CacheKey(self._function_spec, input_args_with_spec, + input_kwargs_with_spec, self._class_instance) # 3. check whether hit the cache or build a new program for the input arguments concrete_program, partial_program_layer = self._program_cache[cache_key] @@ -564,7 +575,8 @@ class ConcreteProgram(object): @staticmethod @switch_to_static_graph - def from_func_spec(func_spec, input_spec, class_instance): + def from_func_spec(func_spec, input_spec, input_kwargs_spec, + class_instance): """ Builds the main_program with specialized inputs and returns outputs of program as fetch_list. @@ -593,6 +605,8 @@ class ConcreteProgram(object): # 1. Adds `fluid.data` layers for input if needed inputs = func_spec.to_static_inputs_with_spec(input_spec, main_program) + kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec, + main_program) if class_instance: inputs = tuple([class_instance] + list(inputs)) @@ -605,7 +619,10 @@ class ConcreteProgram(object): class_instance, False)), param_guard( get_buffers(class_instance, False)): try: - outputs = static_func(*inputs) + if kwargs: + outputs = static_func(*inputs, **kwargs) + else: + outputs = static_func(*inputs) except BaseException as e: # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here. error.attach_error_data(e) @@ -653,7 +670,8 @@ class ProgramCache(object): def _build_once(self, cache_key): concrete_program = ConcreteProgram.from_func_spec( func_spec=cache_key.function_spec, - input_spec=cache_key.input_with_spec, + input_spec=cache_key.input_args_with_spec, + input_kwargs_spec=cache_key.input_kwargs_with_spec, class_instance=cache_key.class_instance) return concrete_program, partial_program_from(concrete_program) diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py index a5c49e4d7d9..91086c31a39 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_declarative.py @@ -264,8 +264,9 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase): concrete_program_5 = foo.get_concrete_program(InputSpec([10])) # 6. specific unknown kwargs `e`=4 - concrete_program_5 = foo.get_concrete_program( - InputSpec([10]), InputSpec([10]), e=4) + with self.assertRaises(TypeError): + concrete_program_5 = foo.get_concrete_program( + InputSpec([10]), InputSpec([10]), e=4) def test_concrete_program(self): with fluid.dygraph.guard(fluid.CPUPlace()): diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py index 4af955e774a..d4995a72bc4 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_dict.py @@ -203,5 +203,43 @@ class TestDictPop2(TestDictPop): self.dygraph_func = test_dic_pop_2 +class NetWithDictPop(paddle.nn.Layer): + def __init__(self): + super(NetWithDictPop, self).__init__() + + @to_static + def forward(self, x, **kwargs): + x = paddle.to_tensor(x) + y = kwargs.pop('y', None) + if y is True: + y = paddle.to_tensor(x) + x += y + + x.mean() + return x + + +class TestDictPop(TestNetWithDict): + def setUp(self): + self.x = np.array([2, 2]).astype('float32') + + def train(self, to_static=False): + prog_trans = ProgramTranslator() + prog_trans.enable(to_static) + with fluid.dygraph.guard(PLACE): + net = NetWithDictPop() + ret = net(z=0, x=self.x, y=True) + return ret.numpy() + + def test_ast_to_func(self): + dygraph_result = self._run_dygraph() + static_result = self._run_static() + + self.assertTrue( + (dygraph_result == static_result).all(), + msg="dygraph result: {}\nstatic result: {}".format(dygraph_result, + static_result)) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py index 88697bc1b36..9dc8c12f245 100644 --- a/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py +++ b/python/paddle/fluid/tests/unittests/dygraph_to_static/test_function_spec.py @@ -20,6 +20,8 @@ from test_declarative import foo_func import unittest +paddle.enable_static() + class TestFunctionSpec(unittest.TestCase): def test_constructor(self): @@ -82,8 +84,9 @@ class TestFunctionSpec(unittest.TestCase): # case 1 foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec]) - input_with_spec = foo_spec.args_to_input_spec( + input_with_spec, _ = foo_spec.args_to_input_spec( (a_tensor, b_tensor, 1, 2), {}) + self.assertTrue(len(input_with_spec) == 4) self.assertTrue(input_with_spec[0] == a_spec) # a self.assertTrue(input_with_spec[1] == b_spec) # b @@ -92,7 +95,8 @@ class TestFunctionSpec(unittest.TestCase): # case 2 foo_spec = FunctionSpec(foo_func, input_spec=[a_spec]) - input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor), {}) + input_with_spec, _ = foo_spec.args_to_input_spec((a_tensor, b_tensor), + {}) self.assertTrue(len(input_with_spec) == 2) self.assertTrue(input_with_spec[0] == a_spec) # a self.assertTupleEqual(input_with_spec[1].shape, (4, 10)) # b.shape -- GitLab