diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 53928ebed1b25fffe0cd3b16e260e87bfec2c4ac..f502a1dfc340dedeb94a519ba847a08be47e7d6f 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -792,14 +792,15 @@ class Model(object): switched by `paddle.disable_static()`. The usage is as follows. But note, the switching between dynamic and static should be before instantiating a Model. The input description, i.e, paddle.static.InputSpec, - must be required. + must be required for static graph. Args: network (paddle.nn.Layer): The network is an instance of paddle.nn.Layer. inputs (InputSpec|list|dict|None): `inputs`, entry points of network, could be a InputSpec instance, or lits of InputSpec instances, - or dict ({name: InputSpec}), and it couldn't be None. + or dict ({name: InputSpec}), and it couldn't be None in static + graph. labels (InputSpec|list|None): `labels`, entry points of network, could be a InputSpec instnace or lits of InputSpec instances, or None. For static graph, if labels is required in loss, @@ -848,9 +849,10 @@ class Model(object): self._optimizer = None self._test_dataloader = None - if not isinstance(inputs, (list, dict, Input)): - raise TypeError( - "'inputs' must be list or dict, and couldn't be None.") + if not in_dygraph_mode(): + if not isinstance(inputs, (list, dict, Input)): + raise TypeError( + "'inputs' must be list or dict, and couldn't be None.") self._inputs = self._verify_spec(inputs, True) self._labels = self._verify_spec(labels) @@ -1863,7 +1865,18 @@ class Model(object): def _verify_spec(self, specs, is_input=False): out_specs = [] - if isinstance(specs, dict): + if specs is None: + # Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function + # to generate `Input`. But how can we know the actual shape of each input tensor? + if is_input: + out_specs = [ + Input( + name=n, shape=[None]) + for n in extract_args(self.network.forward) if n != 'self' + ] + else: + out_specs = to_list(specs) + elif isinstance(specs, dict): assert is_input == False out_specs = [specs[n] \ for n in extract_args(self.network.forward) if n != 'self']