Unverified commit 5fe44571, authored by liym27, committed by GitHub

[Dynamic-to-Static] Support **kwargs as input of the function which is decorated by `jit.save.to_static` (#29098)
Parent 71815637
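In short, this change lets a function decorated with `paddle.jit.to_static` accept extra keyword arguments through `**kwargs`. A minimal usage sketch (illustrative only; the function `scale_fn` and its body are not part of this commit, and it assumes a Paddle 2.0-era build with this patch applied):

```python
import numpy as np
import paddle

@paddle.jit.to_static
def scale_fn(x, **kwargs):
    x = paddle.to_tensor(x)
    # Plain Python values may ride along in kwargs; tensors/ndarrays in
    # kwargs are replaced by InputSpec when the program is traced.
    scale = kwargs.pop('scale', 1)
    return x * scale

x = np.ones([2, 2]).astype('float32')
print(scale_fn(x, scale=2))  # kwargs are forwarded into the static program
```

Each distinct args/kwargs structure produces its own program cache entry, as the `CacheKey` changes below show.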
@@ -91,6 +91,19 @@ class FunctionSpec(object):
         return tuple(args), kwargs
 
+    def _replace_value_with_input_spec(self, args):
+        args_with_spec = []
+        for idx, input_var in enumerate(flatten(args)):
+            if isinstance(input_var, np.ndarray):
+                input_var = paddle.static.InputSpec.from_numpy(input_var)
+            elif isinstance(input_var, core.VarBase):
+                input_var = paddle.static.InputSpec.from_tensor(input_var)
+            args_with_spec.append(input_var)
+        args_with_spec = pack_sequence_as(args, args_with_spec)
+        return args_with_spec
+
     def args_to_input_spec(self, args, kwargs):
         """
         Converts input arguments into InputSpec.
@@ -103,10 +116,11 @@ class FunctionSpec(object):
             kwargs(dict): kwargs arguments received by **kwargs.
 
         Return:
-            Same nest structure with args by replacing value with InputSpec.
+            Same nest structure as args and kwargs, with values replaced by InputSpec.
         """
-        input_with_spec = []
+        args_with_spec = []
+        kwargs_with_spec = []
         if self._input_spec is not None:
             # Note: Because the value type and length of `kwargs` are uncertain,
             # we don't support this case while specifying `input_spec` currently.
@@ -124,24 +138,17 @@ class FunctionSpec(object):
                     format(len(args), len(self._input_spec)))
 
             # replace argument with corresponding InputSpec.
-            input_with_spec = convert_to_input_spec(args, self._input_spec)
+            args_with_spec = convert_to_input_spec(args, self._input_spec)
         else:
-            for idx, input_var in enumerate(flatten(args)):
-                if isinstance(input_var, np.ndarray):
-                    input_var = paddle.static.InputSpec.from_numpy(input_var)
-                elif isinstance(input_var, core.VarBase):
-                    input_var = paddle.static.InputSpec.from_tensor(input_var)
-                input_with_spec.append(input_var)
-            input_with_spec = pack_sequence_as(args, input_with_spec)
+            args_with_spec = self._replace_value_with_input_spec(args)
+            kwargs_with_spec = self._replace_value_with_input_spec(kwargs)
 
         # If no name is specified in input_spec, add a default name
         # according to the argument name of the decorated function.
-        input_with_spec = replace_spec_empty_name(self._arg_names,
-                                                  input_with_spec)
+        args_with_spec = replace_spec_empty_name(self._arg_names,
+                                                 args_with_spec)
 
-        return input_with_spec
+        return args_with_spec, kwargs_with_spec
 
     @switch_to_static_graph
     def to_static_inputs_with_spec(self, input_with_spec, main_program):
...
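The new `_replace_value_with_input_spec` helper leans on Paddle's nest utilities (`flatten` and `pack_sequence_as`) so that only the leaf values change while the nesting of args and kwargs survives. A rough stand-alone sketch of that idea in plain Python (the helper below is hypothetical, not Paddle's implementation):

```python
import numpy as np

def replace_value_with_spec(structure):
    # Recursively walk a nested args/kwargs structure; swap each ndarray
    # leaf for a lightweight (shape, dtype) description and keep the rest.
    if isinstance(structure, dict):
        return {k: replace_value_with_spec(v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        return type(structure)(replace_value_with_spec(v) for v in structure)
    if isinstance(structure, np.ndarray):
        return ('InputSpec', structure.shape, str(structure.dtype))
    return structure  # ints, bools, etc. stay as-is

args = (np.zeros([4, 10], 'float32'), 1)
kwargs = {'y': np.ones([4], 'float32'), 'z': 0}
print(replace_value_with_spec(args))    # (('InputSpec', (4, 10), 'float32'), 1)
print(replace_value_with_spec(kwargs))  # {'y': ('InputSpec', (4,), 'float32'), 'z': 0}
```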
@@ -146,19 +146,25 @@ class CacheKey(object):
     Cached key for ProgramCache.
     """
 
-    __slots__ = ['function_spec', 'input_with_spec', 'class_instance']
+    __slots__ = [
+        'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
+        'class_instance'
+    ]
 
-    def __init__(self, function_spec, input_with_spec, class_instance):
+    def __init__(self, function_spec, input_args_with_spec,
+                 input_kwargs_with_spec, class_instance):
         """
         Initializes a cache key.
 
         Args:
             function_spec(FunctionSpec): a FunctionSpec instance of the decorated function.
-            input_with_spec(list[InputSpec]): actual inputs with some arguments replaced by InputSpec.
+            input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec.
+            input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec.
             class_instance(object): an instance of class `Layer`.
         """
         self.function_spec = function_spec
-        self.input_with_spec = input_with_spec
+        self.input_args_with_spec = input_args_with_spec
+        self.input_kwargs_with_spec = input_kwargs_with_spec
         self.class_instance = class_instance
 
     @classmethod
@@ -177,15 +183,18 @@ class CacheKey(object):
            args = args[1:]
         # 2. convert tensor and numpy array into InputSpec
         _args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs)
-        input_with_spec = function_spec.args_to_input_spec(_args, _kwargs)
+        input_args_with_spec, input_kwargs_with_spec = function_spec.args_to_input_spec(
+            _args, _kwargs)
 
         # 3. check whether hit the cache or build a new program for the input arguments
-        return CacheKey(function_spec, input_with_spec, class_instance)
+        return CacheKey(function_spec, input_args_with_spec,
+                        input_kwargs_with_spec, class_instance)
 
     def __hash__(self):
         error_msg = "Arguments to a `@paddle.jit.to_static` must be hashable Python objects (or nested structures of these types)."
         return hash((id(self.function_spec),
-                     make_hashable(self.input_with_spec, error_msg),
+                     make_hashable(self.input_args_with_spec, error_msg),
+                     make_hashable(self.input_kwargs_with_spec, error_msg),
                      self.class_instance))
 
     def __eq__(self, other):
@@ -195,8 +204,9 @@ class CacheKey(object):
         return not self == other
 
     def __repr__(self):
-        return "id(function_spec): {}, input_with_spec: {}, class_instance: {}".format(
-            id(self.function_spec), self.input_with_spec, self.class_instance)
+        return "id(function_spec): {}, input_args_with_spec: {}, input_kwargs_with_spec: {}, class_instance: {}".format(
+            id(self.function_spec), self.input_args_with_spec,
+            self.input_kwargs_with_spec, self.class_instance)
 
 
 def unwrap_decorators(func):
@@ -380,11 +390,12 @@ class StaticFunction(object):
         if len(args) != len(self._function_spec.args_name):
             args, kwargs = self._function_spec.unified_args_and_kwargs(args,
                                                                        kwargs)
-        input_with_spec = self._function_spec.args_to_input_spec(args, kwargs)
+        input_args_with_spec, input_kwargs_with_spec = self._function_spec.args_to_input_spec(
+            args, kwargs)
 
         # 2. generate cache key
-        cache_key = CacheKey(self._function_spec, input_with_spec,
-                             self._class_instance)
+        cache_key = CacheKey(self._function_spec, input_args_with_spec,
+                             input_kwargs_with_spec, self._class_instance)
 
         # 3. check whether hit the cache or build a new program for the input arguments
         concrete_program, partial_program_layer = self._program_cache[cache_key]
@@ -564,7 +575,8 @@ class ConcreteProgram(object):
 
     @staticmethod
     @switch_to_static_graph
-    def from_func_spec(func_spec, input_spec, class_instance):
+    def from_func_spec(func_spec, input_spec, input_kwargs_spec,
+                       class_instance):
         """
         Builds the main_program with specialized inputs and returns outputs
         of program as fetch_list.
@@ -593,6 +605,8 @@ class ConcreteProgram(object):
             # 1. Adds `fluid.data` layers for input if needed
             inputs = func_spec.to_static_inputs_with_spec(input_spec,
                                                           main_program)
+            kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec,
+                                                          main_program)
             if class_instance:
                 inputs = tuple([class_instance] + list(inputs))
 
@@ -605,6 +619,9 @@ class ConcreteProgram(object):
                     class_instance, False)), param_guard(
                         get_buffers(class_instance, False)):
                 try:
-                    outputs = static_func(*inputs)
+                    if kwargs:
+                        outputs = static_func(*inputs, **kwargs)
+                    else:
+                        outputs = static_func(*inputs)
                 except BaseException as e:
                     # NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
@@ -653,7 +670,8 @@ class ProgramCache(object):
 
     def _build_once(self, cache_key):
         concrete_program = ConcreteProgram.from_func_spec(
             func_spec=cache_key.function_spec,
-            input_spec=cache_key.input_with_spec,
+            input_spec=cache_key.input_args_with_spec,
+            input_kwargs_spec=cache_key.input_kwargs_with_spec,
             class_instance=cache_key.class_instance)
 
         return concrete_program, partial_program_from(concrete_program)
...
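Taken together, the `CacheKey` and `ProgramCache` changes mean kwargs now participate in program caching: two calls that differ only in their keyword arguments hash to different keys and therefore trace different programs. A toy sketch of the hashing idea (plain Python, not Paddle's actual `make_hashable`):

```python
def make_hashable_sketch(obj):
    # Convert nested dicts/lists into tuples so the structure can be hashed.
    if isinstance(obj, dict):
        return tuple((k, make_hashable_sketch(v)) for k, v in sorted(obj.items()))
    if isinstance(obj, (list, tuple)):
        return tuple(make_hashable_sketch(v) for v in obj)
    return obj

args_spec = ('InputSpec([10])',)
key1 = hash((args_spec, make_hashable_sketch({'y': True})))
key2 = hash((args_spec, make_hashable_sketch({'y': False})))
assert key1 != key2  # different kwargs -> different cached programs
```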
@@ -264,6 +264,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
             concrete_program_5 = foo.get_concrete_program(InputSpec([10]))
 
             # 6. specify unknown kwargs `e`=4
-            concrete_program_5 = foo.get_concrete_program(
-                InputSpec([10]), InputSpec([10]), e=4)
+            with self.assertRaises(TypeError):
+                concrete_program_5 = foo.get_concrete_program(
+                    InputSpec([10]), InputSpec([10]), e=4)
...
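Note the asymmetry this test pins down: because the value type and length of `**kwargs` are uncertain, combining an explicit `input_spec` with unknown keyword arguments is rejected with a TypeError rather than silently traced.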
@@ -203,5 +203,43 @@ class TestDictPop2(TestDictPop):
         self.dygraph_func = test_dic_pop_2
 
 
+class NetWithDictPop(paddle.nn.Layer):
+    def __init__(self):
+        super(NetWithDictPop, self).__init__()
+
+    @to_static
+    def forward(self, x, **kwargs):
+        x = paddle.to_tensor(x)
+        y = kwargs.pop('y', None)
+        if y is True:
+            y = paddle.to_tensor(x)
+            x += y
+
+        x.mean()
+        return x
+
+
+class TestDictPop3(TestNetWithDict):
+    def setUp(self):
+        self.x = np.array([2, 2]).astype('float32')
+
+    def train(self, to_static=False):
+        prog_trans = ProgramTranslator()
+        prog_trans.enable(to_static)
+        with fluid.dygraph.guard(PLACE):
+            net = NetWithDictPop()
+            ret = net(z=0, x=self.x, y=True)
+            return ret.numpy()
+
+    def test_ast_to_func(self):
+        dygraph_result = self._run_dygraph()
+        static_result = self._run_static()
+        self.assertTrue(
+            (dygraph_result == static_result).all(),
+            msg="dygraph result: {}\nstatic result: {}".format(dygraph_result,
+                                                               static_result))
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -20,6 +20,8 @@ from test_declarative import foo_func
 
 import unittest
 
+paddle.enable_static()
+
 
 class TestFunctionSpec(unittest.TestCase):
     def test_constructor(self):
@@ -82,8 +84,9 @@ class TestFunctionSpec(unittest.TestCase):
 
         # case 1
         foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
-        input_with_spec = foo_spec.args_to_input_spec(
-            (a_tensor, b_tensor, 1, 2), {})
+        input_with_spec, _ = foo_spec.args_to_input_spec(
+            (a_tensor, b_tensor, 1, 2), {})
         self.assertTrue(len(input_with_spec) == 4)
         self.assertTrue(input_with_spec[0] == a_spec)  # a
         self.assertTrue(input_with_spec[1] == b_spec)  # b
@@ -92,7 +95,8 @@ class TestFunctionSpec(unittest.TestCase):
 
         # case 2
         foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
-        input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor), {})
+        input_with_spec, _ = foo_spec.args_to_input_spec((a_tensor, b_tensor),
+                                                         {})
         self.assertTrue(len(input_with_spec) == 2)
         self.assertTrue(input_with_spec[0] == a_spec)  # a
         self.assertTupleEqual(input_with_spec[1].shape, (4, 10))  # b.shape
...