From 4c5dead62b4d8465a5f5477c571c2965e5ab1e2a Mon Sep 17 00:00:00 2001 From: LiuChiaChi <709153940@qq.com> Date: Fri, 25 Sep 2020 05:43:25 +0000 Subject: [PATCH] remove input requirment in dygraph Model --- python/paddle/hapi/model.py | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 53928ebed1b..f502a1dfc34 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -792,14 +792,15 @@ class Model(object): switched by `paddle.disable_static()`. The usage is as follows. But note, the switching between dynamic and static should be before instantiating a Model. The input description, i.e, paddle.static.InputSpec, - must be required. + must be required for static graph. Args: network (paddle.nn.Layer): The network is an instance of paddle.nn.Layer. inputs (InputSpec|list|dict|None): `inputs`, entry points of network, could be a InputSpec instance, or lits of InputSpec instances, - or dict ({name: InputSpec}), and it couldn't be None. + or dict ({name: InputSpec}), and it couldn't be None in static + graph. labels (InputSpec|list|None): `labels`, entry points of network, could be a InputSpec instnace or lits of InputSpec instances, or None. For static graph, if labels is required in loss, @@ -848,9 +849,10 @@ class Model(object): self._optimizer = None self._test_dataloader = None - if not isinstance(inputs, (list, dict, Input)): - raise TypeError( - "'inputs' must be list or dict, and couldn't be None.") + if not in_dygraph_mode(): + if not isinstance(inputs, (list, dict, Input)): + raise TypeError( + "'inputs' must be list or dict, and couldn't be None.") self._inputs = self._verify_spec(inputs, True) self._labels = self._verify_spec(labels) @@ -1863,7 +1865,18 @@ class Model(object): def _verify_spec(self, specs, is_input=False): out_specs = [] - if isinstance(specs, dict): + if specs is None: + # Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function + # to generate `Input`. But how can we know the actual shape of each input tensor? + if is_input: + out_specs = [ + Input( + name=n, shape=[None]) + for n in extract_args(self.network.forward) if n != 'self' + ] + else: + out_specs = to_list(specs) + elif isinstance(specs, dict): assert is_input == False out_specs = [specs[n] \ for n in extract_args(self.network.forward) if n != 'self'] -- GitLab