diff --git a/python/paddle/hapi/model.py b/python/paddle/hapi/model.py index 459d6cd3284e9a5be2a4d9ad08192a6bcc28ef79..21e3054dde7d7f0e196cfd32298d1a719ea0fa97 100644 --- a/python/paddle/hapi/model.py +++ b/python/paddle/hapi/model.py @@ -201,8 +201,11 @@ def prepare_distributed_context(place=None): def _update_input_shapes(inputs): + "Get input shape list by given inputs in Model initialization." shapes = None - if isinstance(inputs, list): + if isinstance(inputs, Input): + shapes = [list(inputs.shape)] + elif isinstance(inputs, list): shapes = [list(input.shape) for input in inputs] elif isinstance(inputs, dict): shapes = [list(inputs[name].shape) for name in inputs] @@ -917,9 +920,7 @@ class Model(object): """ loss = self._adapter.train_batch(inputs, labels) if fluid.in_dygraph_mode() and self._input_shapes is None: - self._input_shapes = self._adapter._input_shapes - self._is_shape_inferred = True - self._inputs = self._verify_spec(None, self._input_shapes, True) + self._update_inputs() return loss def eval_batch(self, inputs, labels=None): @@ -967,9 +968,7 @@ class Model(object): """ loss = self._adapter.eval_batch(inputs, labels) if fluid.in_dygraph_mode() and self._input_shapes is None: - self._input_shapes = self._adapter._input_shapes - self._is_shape_inferred = True - self._inputs = self._verify_spec(None, self._input_shapes, True) + self._update_inputs() return loss def test_batch(self, inputs): @@ -1012,9 +1011,7 @@ class Model(object): """ loss = self._adapter.test_batch(inputs) if fluid.in_dygraph_mode() and self._input_shapes is None: - self._input_shapes = self._adapter._input_shapes - self._is_shape_inferred = True - self._inputs = self._verify_spec(None, self._input_shapes, True) + self._update_inputs() return loss def save(self, path, training=True): @@ -1707,7 +1704,7 @@ class Model(object): layer = self.network if self._input_shapes is None: # No provided or inferred raise RuntimeError( - "Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training zqqdata and perform a training for shape derivation." + "Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation." ) if self._is_shape_inferred: warnings.warn( @@ -1953,3 +1950,9 @@ class Model(object): except Exception: steps = None return steps + + def _update_inputs(self): + "Update self._inputs according to given inputs." + self._input_shapes = self._adapter._input_shapes + self._is_shape_inferred = True + self._inputs = self._verify_spec(None, self._input_shapes, True) diff --git a/python/paddle/tests/test_model.py b/python/paddle/tests/test_model.py index 8cd5e172aa06a6abada0a2532f02eb220add1c73..56105b6d7f15abdce2f137f188b3ff5c5deec8f0 100644 --- a/python/paddle/tests/test_model.py +++ b/python/paddle/tests/test_model.py @@ -556,9 +556,10 @@ class TestModelFunction(unittest.TestCase): shutil.rmtree(save_dir) paddle.enable_static() - def test_dygraph_export_deploy_model_without_inputs(self): + def test_dygraph_export_deploy_model_about_inputs(self): mnist_data = MnistDataset(mode='train') paddle.disable_static() + # without inputs for initial in ["fit", "train_batch", "eval_batch", "test_batch"]: save_dir = tempfile.mkdtemp() if not os.path.exists(save_dir): @@ -584,6 +585,18 @@ class TestModelFunction(unittest.TestCase): model.save(save_dir, training=False) shutil.rmtree(save_dir) + # with inputs, and the type of inputs is InputSpec + save_dir = tempfile.mkdtemp() + if not os.path.exists(save_dir): + os.makedirs(save_dir) + net = LeNet() + inputs = InputSpec([None, 1, 28, 28], 'float32', 'x') + model = Model(net, inputs) + optim = fluid.optimizer.Adam( + learning_rate=0.001, parameter_list=model.parameters()) + model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum")) + model.save(save_dir, training=False) + shutil.rmtree(save_dir) class TestRaiseError(unittest.TestCase):