未验证 提交 252aeb1a 编写于 作者: A Aurelius84 提交者: GitHub

[Dy2stat]Add naming rule if not specific InputSpec.name (#26997)

* Add naming rule if not specific InputSpec.name

* fix function name typo

* refine comment

* remove print statement
上级 ed292695
......@@ -135,6 +135,11 @@ class FunctionSpec(object):
input_with_spec = pack_sequence_as(args, input_with_spec)
# If without specificing name in input_spec, add default name
# according to argument name from decorated function.
input_with_spec = replace_spec_empty_name(self._arg_names,
input_with_spec)
return input_with_spec
@switch_to_static_graph
......@@ -309,3 +314,61 @@ def convert_to_input_spec(inputs, input_spec):
raise TypeError(
"The type(input_spec) should be a `InputSpec` or dict/list/tuple of it, but received {}.".
type_name(input_spec))
def replace_spec_empty_name(args_name, input_with_spec):
"""
Adds default name according to argument name from decorated function
if without specificing InputSpec.name
The naming rule are as followed:
1. If InputSpec.name is not None, do nothing.
2. If each argument `x` corresponds to an InputSpec, using the argument name like `x`
3. If the arguments `inputs` corresponds to a list(InputSpec), using name like `inputs_0`, `inputs_1`
4. If the arguments `input_dic` corresponds to a dict(InputSpec), using key as name.
For example:
# case 1: foo(x, y)
foo = to_static(foo, input_spec=[InputSpec([None, 10]), InputSpec([None])])
print([in_var.name for in_var in foo.inputs]) # [x, y]
# case 2: foo(inputs) where inputs is a list
foo = to_static(foo, input_spec=[[InputSpec([None, 10]), InputSpec([None])]])
print([in_var.name for in_var in foo.inputs]) # [inputs_0, inputs_1]
# case 3: foo(inputs) where inputs is a dict
foo = to_static(foo, input_spec=[{'x': InputSpec([None, 10]), 'y': InputSpec([None])}])
print([in_var.name for in_var in foo.inputs]) # [x, y]
"""
input_with_spec = list(input_with_spec)
candidate_arg_names = args_name[:len(input_with_spec)]
for i, arg_name in enumerate(candidate_arg_names):
input_spec = input_with_spec[i]
input_with_spec[i] = _replace_spec_name(arg_name, input_spec)
return input_with_spec
def _replace_spec_name(name, input_spec):
"""
Replaces InputSpec.name with given `name` while not specificing it.
"""
if isinstance(input_spec, paddle.static.InputSpec):
if input_spec.name is None:
input_spec.name = name
return input_spec
elif isinstance(input_spec, (list, tuple)):
processed_specs = []
for i, spec in enumerate(input_spec):
new_name = "{}_{}".format(name, i)
processed_specs.append(_replace_spec_name(new_name, spec))
return processed_specs
elif isinstance(input_spec, dict):
processed_specs = {}
for key, spec in six.iteritems(input_spec):
processed_specs[key] = _replace_spec_name(key, spec)
return processed_specs
else:
return input_spec
......@@ -47,8 +47,8 @@ class SimpleNet(Layer):
return z
@declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]])
def func_with_list(self, l):
x, y, int_val = l
def func_with_list(self, l, int_val=1):
x, y = l
z = x + y
z = z + int_val
return z
......@@ -60,10 +60,7 @@ class SimpleNet(Layer):
def func_with_dict(self, d):
x = d['x']
y = d['y']
int_val = d['int_val']
z = x + y
z = z + int_val
return z
......@@ -131,10 +128,10 @@ class TestInputSpec(unittest.TestCase):
self.assertTrue(len(net.add_func.program_cache) == 1)
# 5. test input with list
out = net.func_with_list([x, y, int_val])
out = net.func_with_list([x, y], int_val)
# 6. test input with dict
out = net.func_with_dict({'x': x, 'y': y, 'int_val': int_val})
out = net.func_with_dict({'x': x, 'y': y})
# 7. test input with lits contains dict
int_np = np.ones([1]).astype('float32')
......@@ -293,6 +290,30 @@ class TestDifferentInputSpecCacheProgram(unittest.TestCase):
foo_3.concrete_program
class TestInputDefaultName(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.net = SimpleNet()
def assert_default_name(self, func_name, input_names):
decorated_func = getattr(self.net, func_name)
spec_names = [x.name for x in decorated_func.inputs]
self.assertListEqual(spec_names, input_names)
def test_common_input(self):
self.assert_default_name('forward', ['x'])
def test_list_input(self):
self.assert_default_name('func_with_list', ['l_0', 'l_1'])
def test_dict_input(self):
self.assert_default_name('func_with_dict', ['x', 'y'])
def test_nest_input(self):
self.assert_default_name('func_with_list_dict', ['dl_0', 'x', 'y'])
class TestDeclarativeAPI(unittest.TestCase):
def test_error(self):
func = declarative(dyfunc_to_variable)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册