未验证 提交 5fe44571 编写于 作者: L liym27 提交者: GitHub

[Dynamic-to-Static] Support **kwargs as input of the function which is...

[Dynamic-to-Static] Support **kwargs as input of the function which is decorated by `jit.save.to_static` (#29098)
上级 71815637
......@@ -91,6 +91,19 @@ class FunctionSpec(object):
return tuple(args), kwargs
def _replace_value_with_input_spec(self, args):
args_with_spec = []
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
elif isinstance(input_var, core.VarBase):
input_var = paddle.static.InputSpec.from_tensor(input_var)
args_with_spec.append(input_var)
args_with_spec = pack_sequence_as(args, args_with_spec)
return args_with_spec
def args_to_input_spec(self, args, kwargs):
"""
Converts input arguments into InputSpec.
......@@ -103,10 +116,11 @@ class FunctionSpec(object):
kwargs(dict): kwargs arguments received by **kwargs.
Return:
Same nest structure with args by replacing value with InputSpec.
Same nest structure with args and kwargs by replacing value with InputSpec.
"""
input_with_spec = []
args_with_spec = []
kwargs_with_spec = []
if self._input_spec is not None:
# Note: Because the value type and length of `kwargs` is uncertain.
# So we don't support to deal this case while specificing `input_spec` currently.
......@@ -124,24 +138,17 @@ class FunctionSpec(object):
format(len(args), len(self._input_spec)))
# replace argument with corresponding InputSpec.
input_with_spec = convert_to_input_spec(args, self._input_spec)
args_with_spec = convert_to_input_spec(args, self._input_spec)
else:
for idx, input_var in enumerate(flatten(args)):
if isinstance(input_var, np.ndarray):
input_var = paddle.static.InputSpec.from_numpy(input_var)
elif isinstance(input_var, core.VarBase):
input_var = paddle.static.InputSpec.from_tensor(input_var)
input_with_spec.append(input_var)
input_with_spec = pack_sequence_as(args, input_with_spec)
args_with_spec = self._replace_value_with_input_spec(args)
kwargs_with_spec = self._replace_value_with_input_spec(kwargs)
# If without specificing name in input_spec, add default name
# according to argument name from decorated function.
input_with_spec = replace_spec_empty_name(self._arg_names,
input_with_spec)
args_with_spec = replace_spec_empty_name(self._arg_names,
args_with_spec)
return input_with_spec
return args_with_spec, kwargs_with_spec
@switch_to_static_graph
def to_static_inputs_with_spec(self, input_with_spec, main_program):
......
......@@ -146,19 +146,25 @@ class CacheKey(object):
Cached key for ProgramCache.
"""
__slots__ = ['function_spec', 'input_with_spec', 'class_instance']
__slots__ = [
'function_spec', 'input_args_with_spec', 'input_kwargs_with_spec',
'class_instance'
]
def __init__(self, function_spec, input_with_spec, class_instance):
def __init__(self, function_spec, input_args_with_spec,
input_kwargs_with_spec, class_instance):
"""
Initializes a cache key.
Args:
functions_spec(FunctionSpec): a FunctionSpec instance of decorated function.
input_with_spec(list[InputSpec]): actual inputs with some arguments replaced by InputSpec.
input_args_with_spec(list[InputSpec]): actual input args with some arguments replaced by InputSpec.
input_kwargs_with_spec(list[{string:InputSpec}]): actual input kwargs with some arguments replaced by InputSpec.
class_instance(object): a instance of class `Layer`.
"""
self.function_spec = function_spec
self.input_with_spec = input_with_spec
self.input_args_with_spec = input_args_with_spec
self.input_kwargs_with_spec = input_kwargs_with_spec
self.class_instance = class_instance
@classmethod
......@@ -177,15 +183,18 @@ class CacheKey(object):
args = args[1:]
# 2. convert tensor and numpy array into InputSpec
_args, _kwargs = function_spec.unified_args_and_kwargs(args, kwargs)
input_with_spec = function_spec.args_to_input_spec(_args, _kwargs)
input_args_with_spec, input_kwargs_with_spec = function_spec.args_to_input_spec(
_args, _kwargs)
# 3. check whether hit the cache or build a new program for the input arguments
return CacheKey(function_spec, input_with_spec, class_instance)
return CacheKey(function_spec, input_args_with_spec,
input_kwargs_with_spec, class_instance)
def __hash__(self):
error_msg = "Arguments to a `@paddle.jit.to_static` must be a hashable Python objects (or nested structures of these types)."
return hash((id(self.function_spec),
make_hashable(self.input_with_spec, error_msg),
make_hashable(self.input_args_with_spec, error_msg),
make_hashable(self.input_kwargs_with_spec, error_msg),
self.class_instance))
def __eq__(self, other):
......@@ -195,8 +204,9 @@ class CacheKey(object):
return not self == other
def __repr__(self):
return "id(function_spec): {}, input_with_spec: {}, class_instance: {}".format(
id(self.function_spec), self.input_with_spec, self.class_instance)
return "id(function_spec): {}, input_args_with_spec: {}, input_kwargs_with_spec: {}, class_instance: {}".format(
id(self.function_spec), self.input_args_with_spec,
self.input_kwargs_with_spec, self.class_instance)
def unwrap_decorators(func):
......@@ -380,11 +390,12 @@ class StaticFunction(object):
if len(args) != len(self._function_spec.args_name):
args, kwargs = self._function_spec.unified_args_and_kwargs(args,
kwargs)
input_with_spec = self._function_spec.args_to_input_spec(args, kwargs)
input_args_with_spec, input_kwargs_with_spec = self._function_spec.args_to_input_spec(
args, kwargs)
# 2. generate cache key
cache_key = CacheKey(self._function_spec, input_with_spec,
self._class_instance)
cache_key = CacheKey(self._function_spec, input_args_with_spec,
input_kwargs_with_spec, self._class_instance)
# 3. check whether hit the cache or build a new program for the input arguments
concrete_program, partial_program_layer = self._program_cache[cache_key]
......@@ -564,7 +575,8 @@ class ConcreteProgram(object):
@staticmethod
@switch_to_static_graph
def from_func_spec(func_spec, input_spec, class_instance):
def from_func_spec(func_spec, input_spec, input_kwargs_spec,
class_instance):
"""
Builds the main_program with specialized inputs and returns outputs
of program as fetch_list.
......@@ -593,6 +605,8 @@ class ConcreteProgram(object):
# 1. Adds `fluid.data` layers for input if needed
inputs = func_spec.to_static_inputs_with_spec(input_spec,
main_program)
kwargs = func_spec.to_static_inputs_with_spec(input_kwargs_spec,
main_program)
if class_instance:
inputs = tuple([class_instance] + list(inputs))
......@@ -605,6 +619,9 @@ class ConcreteProgram(object):
class_instance, False)), param_guard(
get_buffers(class_instance, False)):
try:
if kwargs:
outputs = static_func(*inputs, **kwargs)
else:
outputs = static_func(*inputs)
except BaseException as e:
# NOTE: If e is raised in compile time, e should be attached to ERROR_DATA here.
......@@ -653,7 +670,8 @@ class ProgramCache(object):
def _build_once(self, cache_key):
concrete_program = ConcreteProgram.from_func_spec(
func_spec=cache_key.function_spec,
input_spec=cache_key.input_with_spec,
input_spec=cache_key.input_args_with_spec,
input_kwargs_spec=cache_key.input_kwargs_with_spec,
class_instance=cache_key.class_instance)
return concrete_program, partial_program_from(concrete_program)
......
......@@ -264,6 +264,7 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
concrete_program_5 = foo.get_concrete_program(InputSpec([10]))
# 6. specific unknown kwargs `e`=4
with self.assertRaises(TypeError):
concrete_program_5 = foo.get_concrete_program(
InputSpec([10]), InputSpec([10]), e=4)
......
......@@ -203,5 +203,43 @@ class TestDictPop2(TestDictPop):
self.dygraph_func = test_dic_pop_2
class NetWithDictPop(paddle.nn.Layer):
def __init__(self):
super(NetWithDictPop, self).__init__()
@to_static
def forward(self, x, **kwargs):
x = paddle.to_tensor(x)
y = kwargs.pop('y', None)
if y is True:
y = paddle.to_tensor(x)
x += y
x.mean()
return x
class TestDictPop(TestNetWithDict):
def setUp(self):
self.x = np.array([2, 2]).astype('float32')
def train(self, to_static=False):
prog_trans = ProgramTranslator()
prog_trans.enable(to_static)
with fluid.dygraph.guard(PLACE):
net = NetWithDictPop()
ret = net(z=0, x=self.x, y=True)
return ret.numpy()
def test_ast_to_func(self):
dygraph_result = self._run_dygraph()
static_result = self._run_static()
self.assertTrue(
(dygraph_result == static_result).all(),
msg="dygraph result: {}\nstatic result: {}".format(dygraph_result,
static_result))
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,8 @@ from test_declarative import foo_func
import unittest
paddle.enable_static()
class TestFunctionSpec(unittest.TestCase):
def test_constructor(self):
......@@ -82,8 +84,9 @@ class TestFunctionSpec(unittest.TestCase):
# case 1
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec, b_spec])
input_with_spec = foo_spec.args_to_input_spec(
input_with_spec, _ = foo_spec.args_to_input_spec(
(a_tensor, b_tensor, 1, 2), {})
self.assertTrue(len(input_with_spec) == 4)
self.assertTrue(input_with_spec[0] == a_spec) # a
self.assertTrue(input_with_spec[1] == b_spec) # b
......@@ -92,7 +95,8 @@ class TestFunctionSpec(unittest.TestCase):
# case 2
foo_spec = FunctionSpec(foo_func, input_spec=[a_spec])
input_with_spec = foo_spec.args_to_input_spec((a_tensor, b_tensor), {})
input_with_spec, _ = foo_spec.args_to_input_spec((a_tensor, b_tensor),
{})
self.assertTrue(len(input_with_spec) == 2)
self.assertTrue(input_with_spec[0] == a_spec) # a
self.assertTupleEqual(input_with_spec[1].shape, (4, 10)) # b.shape
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册