未验证 提交 b84d4ae3 编写于 作者: L LiuChiachi 提交者: GitHub

Fix bug of Model.save (#27815)

* fix model bugs, inputs can be InputSpec instance

* correct error message
上级 f3e2580c
...@@ -201,8 +201,11 @@ def prepare_distributed_context(place=None): ...@@ -201,8 +201,11 @@ def prepare_distributed_context(place=None):
def _update_input_shapes(inputs): def _update_input_shapes(inputs):
"Get input shape list by given inputs in Model initialization."
shapes = None shapes = None
if isinstance(inputs, list): if isinstance(inputs, Input):
shapes = [list(inputs.shape)]
elif isinstance(inputs, list):
shapes = [list(input.shape) for input in inputs] shapes = [list(input.shape) for input in inputs]
elif isinstance(inputs, dict): elif isinstance(inputs, dict):
shapes = [list(inputs[name].shape) for name in inputs] shapes = [list(inputs[name].shape) for name in inputs]
...@@ -917,9 +920,7 @@ class Model(object): ...@@ -917,9 +920,7 @@ class Model(object):
""" """
loss = self._adapter.train_batch(inputs, labels) loss = self._adapter.train_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None: if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes self._update_inputs()
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
return loss return loss
def eval_batch(self, inputs, labels=None): def eval_batch(self, inputs, labels=None):
...@@ -967,9 +968,7 @@ class Model(object): ...@@ -967,9 +968,7 @@ class Model(object):
""" """
loss = self._adapter.eval_batch(inputs, labels) loss = self._adapter.eval_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None: if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes self._update_inputs()
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
return loss return loss
def test_batch(self, inputs): def test_batch(self, inputs):
...@@ -1012,9 +1011,7 @@ class Model(object): ...@@ -1012,9 +1011,7 @@ class Model(object):
""" """
loss = self._adapter.test_batch(inputs) loss = self._adapter.test_batch(inputs)
if fluid.in_dygraph_mode() and self._input_shapes is None: if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes self._update_inputs()
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
return loss return loss
def save(self, path, training=True): def save(self, path, training=True):
...@@ -1707,7 +1704,7 @@ class Model(object): ...@@ -1707,7 +1704,7 @@ class Model(object):
layer = self.network layer = self.network
if self._input_shapes is None: # No provided or inferred if self._input_shapes is None: # No provided or inferred
raise RuntimeError( raise RuntimeError(
"Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training zqqdata and perform a training for shape derivation." "Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation."
) )
if self._is_shape_inferred: if self._is_shape_inferred:
warnings.warn( warnings.warn(
...@@ -1953,3 +1950,9 @@ class Model(object): ...@@ -1953,3 +1950,9 @@ class Model(object):
except Exception: except Exception:
steps = None steps = None
return steps return steps
def _update_inputs(self):
"Update self._inputs according to given inputs."
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
...@@ -556,9 +556,10 @@ class TestModelFunction(unittest.TestCase): ...@@ -556,9 +556,10 @@ class TestModelFunction(unittest.TestCase):
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
paddle.enable_static() paddle.enable_static()
def test_dygraph_export_deploy_model_without_inputs(self): def test_dygraph_export_deploy_model_about_inputs(self):
mnist_data = MnistDataset(mode='train') mnist_data = MnistDataset(mode='train')
paddle.disable_static() paddle.disable_static()
# without inputs
for initial in ["fit", "train_batch", "eval_batch", "test_batch"]: for initial in ["fit", "train_batch", "eval_batch", "test_batch"]:
save_dir = tempfile.mkdtemp() save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir): if not os.path.exists(save_dir):
...@@ -584,6 +585,18 @@ class TestModelFunction(unittest.TestCase): ...@@ -584,6 +585,18 @@ class TestModelFunction(unittest.TestCase):
model.save(save_dir, training=False) model.save(save_dir, training=False)
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
# with inputs, and the type of inputs is InputSpec
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
net = LeNet()
inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')
model = Model(net, inputs)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.save(save_dir, training=False)
shutil.rmtree(save_dir)
class TestRaiseError(unittest.TestCase): class TestRaiseError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册