提交 4c5dead6 编写于 作者: L LiuChiaChi

remove input requirment in dygraph Model

上级 77a36f89
...@@ -792,14 +792,15 @@ class Model(object): ...@@ -792,14 +792,15 @@ class Model(object):
switched by `paddle.disable_static()`. The usage is as follows. switched by `paddle.disable_static()`. The usage is as follows.
But note, the switching between dynamic and static should be before But note, the switching between dynamic and static should be before
instantiating a Model. The input description, i.e, paddle.static.InputSpec, instantiating a Model. The input description, i.e, paddle.static.InputSpec,
must be required. must be required for static graph.
Args: Args:
network (paddle.nn.Layer): The network is an instance of network (paddle.nn.Layer): The network is an instance of
paddle.nn.Layer. paddle.nn.Layer.
inputs (InputSpec|list|dict|None): `inputs`, entry points of network, inputs (InputSpec|list|dict|None): `inputs`, entry points of network,
could be a InputSpec instance, or lits of InputSpec instances, could be a InputSpec instance, or lits of InputSpec instances,
or dict ({name: InputSpec}), and it couldn't be None. or dict ({name: InputSpec}), and it couldn't be None in static
graph.
labels (InputSpec|list|None): `labels`, entry points of network, labels (InputSpec|list|None): `labels`, entry points of network,
could be a InputSpec instnace or lits of InputSpec instances, could be a InputSpec instnace or lits of InputSpec instances,
or None. For static graph, if labels is required in loss, or None. For static graph, if labels is required in loss,
...@@ -848,6 +849,7 @@ class Model(object): ...@@ -848,6 +849,7 @@ class Model(object):
self._optimizer = None self._optimizer = None
self._test_dataloader = None self._test_dataloader = None
if not in_dygraph_mode():
if not isinstance(inputs, (list, dict, Input)): if not isinstance(inputs, (list, dict, Input)):
raise TypeError( raise TypeError(
"'inputs' must be list or dict, and couldn't be None.") "'inputs' must be list or dict, and couldn't be None.")
...@@ -1863,7 +1865,18 @@ class Model(object): ...@@ -1863,7 +1865,18 @@ class Model(object):
def _verify_spec(self, specs, is_input=False): def _verify_spec(self, specs, is_input=False):
out_specs = [] out_specs = []
if isinstance(specs, dict): if specs is None:
# Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`. But how can we know the actual shape of each input tensor?
if is_input:
out_specs = [
Input(
name=n, shape=[None])
for n in extract_args(self.network.forward) if n != 'self'
]
else:
out_specs = to_list(specs)
elif isinstance(specs, dict):
assert is_input == False assert is_input == False
out_specs = [specs[n] \ out_specs = [specs[n] \
for n in extract_args(self.network.forward) if n != 'self'] for n in extract_args(self.network.forward) if n != 'self']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册