diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 936bc77d491eedbd3a0cb289319d9cda3b647cd2..ef06e7e0b5b7a777765ef6dfc843cbc7ead17bfb 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -858,7 +858,10 @@ class Model(object): raise TypeError( "'inputs' must be list or dict, and couldn't be None.") elif inputs: - self._shapes = [list(input.shape) for input in inputs] + if isinstance(inputs, list): + self._shapes = [list(input.shape) for input in inputs] + elif isinstance(inputs, dict): + self._shapes = [list(inputs[name]) for name in inputs] self._inputs = self._verify_spec(inputs, is_input=True) self._labels = self._verify_spec(labels)