未验证 提交 10af966a 编写于 作者: C CtfGo 提交者: GitHub

update the TraceLayer.save_inference_model method with add file suffix automatically (#31989)

As the title
上级 f5186c3c
......@@ -1244,13 +1244,16 @@ class TracedLayer(object):
return self._run(self._build_feed(inputs))
@switch_to_static_graph
def save_inference_model(self, dirname, feed=None, fetch=None):
def save_inference_model(self, path, feed=None, fetch=None):
"""
Save the TracedLayer to a model for inference. The saved
inference model can be loaded by C++ inference APIs.
``path`` is the prefix of saved objects, and the saved translated program file
suffix is ``.pdmodel`` , the saved persistable variables file suffix is ``.pdiparams`` .
Args:
dirname (str): the directory to save the inference model.
path(str): The path prefix to save model. The format is ``dirname/file_prefix`` or ``file_prefix``.
feed (list[int], optional): the input variable indices of the saved
inference model. If None, all input variables of the
TracedLayer object would be the inputs of the saved inference
......@@ -1294,7 +1297,7 @@ class TracedLayer(object):
fetch, = exe.run(program, feed={feed_vars[0]: in_np}, fetch_list=fetch_vars)
print(fetch.shape) # (2, 10)
"""
check_type(dirname, "dirname", str,
check_type(path, "path", str,
"fluid.dygraph.jit.TracedLayer.save_inference_model")
check_type(feed, "feed", (type(None), list),
"fluid.dygraph.jit.TracedLayer.save_inference_model")
......@@ -1309,6 +1312,18 @@ class TracedLayer(object):
check_type(f, "each element of fetch", int,
"fluid.dygraph.jit.TracedLayer.save_inference_model")
# path check
file_prefix = os.path.basename(path)
if file_prefix == "":
raise ValueError(
"The input path MUST be format of dirname/file_prefix "
"[dirname\\file_prefix in Windows system], but received "
"file_prefix is empty string.")
dirname = os.path.dirname(path)
if dirname and not os.path.exists(dirname):
os.makedirs(dirname)
from paddle.fluid.io import save_inference_model
def get_feed_fetch(all_vars, partial_vars):
......@@ -1326,9 +1341,14 @@ class TracedLayer(object):
assert target_var is not None, "{} cannot be found".format(name)
target_vars.append(target_var)
model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX
save_inference_model(
dirname=dirname,
feeded_var_names=feeded_var_names,
target_vars=target_vars,
executor=self._exe,
main_program=self._program.clone())
main_program=self._program.clone(),
model_filename=model_filename,
params_filename=params_filename)
......@@ -75,10 +75,12 @@ class TestTracedLayerRecordNonPersistableInput(unittest.TestCase):
self.assertEqual(actual_persistable_vars, expected_persistable_vars)
dirname = './traced_layer_test_non_persistable_vars'
traced_layer.save_inference_model(dirname=dirname)
filenames = set([f for f in os.listdir(dirname) if f != '__model__'])
self.assertEqual(filenames, expected_persistable_vars)
traced_layer.save_inference_model(
path='./traced_layer_test_non_persistable_vars')
self.assertTrue('traced_layer_test_non_persistable_vars.pdmodel' in
os.listdir('./'))
self.assertTrue('traced_layer_test_non_persistable_vars.pdiparams' in
os.listdir('./'))
if __name__ == '__main__':
......
......@@ -18,6 +18,7 @@ import paddle.fluid as fluid
import six
import unittest
import paddle.nn as nn
import os
class SimpleFCLayer(nn.Layer):
......@@ -115,36 +116,41 @@ class TestTracedLayerErrMsg(unittest.TestCase):
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
self.layer, [in_x])
dirname = './traced_layer_err_msg'
path = './traced_layer_err_msg'
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model([0])
self.assertEqual(
"The type of 'dirname' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'str'>, but received <{} 'list'>. ".
"The type of 'path' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'str'>, but received <{} 'list'>. ".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, [0], [None])
traced_layer.save_inference_model(path, [0], [None])
self.assertEqual(
"The type of 'each element of fetch' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. ".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, [0], False)
traced_layer.save_inference_model(path, [0], False)
self.assertEqual(
"The type of 'fetch' in fluid.dygraph.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. ".
format(self.type_str, self.type_str, self.type_str),
str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, [None], [0])
traced_layer.save_inference_model(path, [None], [0])
self.assertEqual(
"The type of 'each element of feed' in fluid.dygraph.jit.TracedLayer.save_inference_model must be <{} 'int'>, but received <{} 'NoneType'>. ".
format(self.type_str, self.type_str), str(e.exception))
with self.assertRaises(TypeError) as e:
traced_layer.save_inference_model(dirname, True, [0])
traced_layer.save_inference_model(path, True, [0])
self.assertEqual(
"The type of 'feed' in fluid.dygraph.jit.TracedLayer.save_inference_model must be (<{} 'NoneType'>, <{} 'list'>), but received <{} 'bool'>. ".
format(self.type_str, self.type_str, self.type_str),
str(e.exception))
with self.assertRaises(ValueError) as e:
traced_layer.save_inference_model("")
self.assertEqual(
"The input path MUST be format of dirname/file_prefix [dirname\\file_prefix in Windows system], "
"but received file_prefix is empty string.", str(e.exception))
traced_layer.save_inference_model(dirname)
traced_layer.save_inference_model(path)
def _train_simple_net(self):
layer = None
......@@ -174,5 +180,25 @@ class TestOutVarWithNoneErrMsg(unittest.TestCase):
[in_x])
class TestTracedLayerSaveInferenceModel(unittest.TestCase):
"""test save_inference_model will automaticlly create non-exist dir"""
def setUp(self):
self.save_path = "./nonexist_dir/fc"
import shutil
if os.path.exists(os.path.dirname(self.save_path)):
shutil.rmtree(os.path.dirname(self.save_path))
def test_mkdir_when_input_path_non_exist(self):
fc_layer = SimpleFCLayer(3, 4, 2)
input_var = paddle.to_tensor(np.random.random([4, 3]).astype('float32'))
with fluid.dygraph.guard():
dygraph_out, traced_layer = fluid.dygraph.TracedLayer.trace(
fc_layer, inputs=[input_var])
self.assertFalse(os.path.exists(os.path.dirname(self.save_path)))
traced_layer.save_inference_model(self.save_path)
self.assertTrue(os.path.exists(os.path.dirname(self.save_path)))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册