未验证 提交 b84d4ae3 编写于 作者: L LiuChiachi 提交者: GitHub

Fix bug of Model.save (#27815)

* fix model bugs, inputs can be InputSpec instance

* correct error message
上级 f3e2580c
......@@ -201,8 +201,11 @@ def prepare_distributed_context(place=None):
def _update_input_shapes(inputs):
"Get input shape list by given inputs in Model initialization."
shapes = None
if isinstance(inputs, list):
if isinstance(inputs, Input):
shapes = [list(inputs.shape)]
elif isinstance(inputs, list):
shapes = [list(input.shape) for input in inputs]
elif isinstance(inputs, dict):
shapes = [list(inputs[name].shape) for name in inputs]
......@@ -917,9 +920,7 @@ class Model(object):
"""
loss = self._adapter.train_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
self._update_inputs()
return loss
def eval_batch(self, inputs, labels=None):
......@@ -967,9 +968,7 @@ class Model(object):
"""
loss = self._adapter.eval_batch(inputs, labels)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
self._update_inputs()
return loss
def test_batch(self, inputs):
......@@ -1012,9 +1011,7 @@ class Model(object):
"""
loss = self._adapter.test_batch(inputs)
if fluid.in_dygraph_mode() and self._input_shapes is None:
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
self._update_inputs()
return loss
def save(self, path, training=True):
......@@ -1707,7 +1704,7 @@ class Model(object):
layer = self.network
if self._input_shapes is None: # No provided or inferred
raise RuntimeError(
"Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training zqqdata and perform a training for shape derivation."
"Saving inference model needs 'inputs' or running before saving. Please specify 'inputs' in Model initialization or input training data and perform a training for shape derivation."
)
if self._is_shape_inferred:
warnings.warn(
......@@ -1953,3 +1950,9 @@ class Model(object):
except Exception:
steps = None
return steps
def _update_inputs(self):
"Update self._inputs according to given inputs."
self._input_shapes = self._adapter._input_shapes
self._is_shape_inferred = True
self._inputs = self._verify_spec(None, self._input_shapes, True)
......@@ -556,9 +556,10 @@ class TestModelFunction(unittest.TestCase):
shutil.rmtree(save_dir)
paddle.enable_static()
def test_dygraph_export_deploy_model_without_inputs(self):
def test_dygraph_export_deploy_model_about_inputs(self):
mnist_data = MnistDataset(mode='train')
paddle.disable_static()
# without inputs
for initial in ["fit", "train_batch", "eval_batch", "test_batch"]:
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
......@@ -584,6 +585,18 @@ class TestModelFunction(unittest.TestCase):
model.save(save_dir, training=False)
shutil.rmtree(save_dir)
# with inputs, and the type of inputs is InputSpec
save_dir = tempfile.mkdtemp()
if not os.path.exists(save_dir):
os.makedirs(save_dir)
net = LeNet()
inputs = InputSpec([None, 1, 28, 28], 'float32', 'x')
model = Model(net, inputs)
optim = fluid.optimizer.Adam(
learning_rate=0.001, parameter_list=model.parameters())
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
model.save(save_dir, training=False)
shutil.rmtree(save_dir)
class TestRaiseError(unittest.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册