未验证 提交 69279207 编写于 作者: L LiuChiachi 提交者: GitHub

Update hapi.model._save_inference_model by using new features of dy2stat in 2.0-beta API (#27272)

* update model.save_inference_model

* update doc for _save_inference_model, delete useless class in unittests

* make users not be able to set model._inputs be None

* update usage of Model class in unittests

* fix bugs of _verify_spec

* fix bugs of _verify_spec

* add unittest to increase coverage rate

* delete http.log

* update doc for save, remove requirments and limitations for using

* update doc for class Model
上级 f0a5eef5
...@@ -792,15 +792,14 @@ class Model(object): ...@@ -792,15 +792,14 @@ class Model(object):
switched by `paddle.disable_static()`. The usage is as follows. switched by `paddle.disable_static()`. The usage is as follows.
But note, the switching between dynamic and static should be before But note, the switching between dynamic and static should be before
instantiating a Model. The input description, i.e, paddle.static.InputSpec, instantiating a Model. The input description, i.e, paddle.static.InputSpec,
must be required for static graph. must be required.
Args: Args:
network (paddle.nn.Layer): The network is an instance of network (paddle.nn.Layer): The network is an instance of
paddle.nn.Layer. paddle.nn.Layer.
inputs (InputSpec|list|dict|None): `inputs`, entry points of network, inputs (InputSpec|list|dict|None): `inputs`, entry points of network,
could be a InputSpec instance, or lits of InputSpec instances, could be a InputSpec instance, or lits of InputSpec instances,
or dict ({name: InputSpec}), or None. For static graph, or dict ({name: InputSpec}), and it couldn't be None.
inputs must be set. For dynamic graph, it could be None.
labels (InputSpec|list|None): `labels`, entry points of network, labels (InputSpec|list|None): `labels`, entry points of network,
could be a InputSpec instnace or lits of InputSpec instances, could be a InputSpec instnace or lits of InputSpec instances,
or None. For static graph, if labels is required in loss, or None. For static graph, if labels is required in loss,
...@@ -849,10 +848,9 @@ class Model(object): ...@@ -849,10 +848,9 @@ class Model(object):
self._optimizer = None self._optimizer = None
self._test_dataloader = None self._test_dataloader = None
if not in_dygraph_mode(): if not isinstance(inputs, (list, dict, Input)):
if not isinstance(inputs, (list, dict, Input)): raise TypeError(
raise TypeError( "'inputs' must be list or dict, and couldn't be None.")
"'inputs' must be list or dict in static graph mode")
self._inputs = self._verify_spec(inputs, True) self._inputs = self._verify_spec(inputs, True)
self._labels = self._verify_spec(labels) self._labels = self._verify_spec(labels)
...@@ -1004,11 +1002,7 @@ class Model(object): ...@@ -1004,11 +1002,7 @@ class Model(object):
have no variable need to save (like SGD), the fill will not generated). have no variable need to save (like SGD), the fill will not generated).
This function will silently overwrite existing file at the target location. This function will silently overwrite existing file at the target location.
If `training` is set to False, only inference model will be saved. It If `training` is set to False, only inference model will be saved.
should be noted that before using `save`, you should run the model, and
the shape of input you saved is as same as the input of its running.
`@paddle.jit.to_static` must be added on `forward` function of your layer
in dynamic mode now and these will be optimized later.
Args: Args:
path (str): The file prefix to save model. The format is path (str): The file prefix to save model. The format is
...@@ -1037,8 +1031,6 @@ class Model(object): ...@@ -1037,8 +1031,6 @@ class Model(object):
nn.Linear(200, 10), nn.Linear(200, 10),
nn.Softmax()) nn.Softmax())
# If save for inference in dygraph, need this
@paddle.jit.to_static
def forward(self, x): def forward(self, x):
return self.net(x) return self.net(x)
...@@ -1046,7 +1038,7 @@ class Model(object): ...@@ -1046,7 +1038,7 @@ class Model(object):
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
# if use static graph, do not set # if use static graph, do not set
paddle.disable_static(device) if dynamic else None paddle.disable_static(device) if dynamic else None
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x') input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label') label = InputSpec([None, 1], 'int64', 'label')
model = paddle.Model(Mnist(), input, label) model = paddle.Model(Mnist(), input, label)
...@@ -1649,10 +1641,6 @@ class Model(object): ...@@ -1649,10 +1641,6 @@ class Model(object):
model_only=False): model_only=False):
""" """
Save inference model can be in static or dynamic mode. Save inference model can be in static or dynamic mode.
It should be noted that before using `save_inference_model`, you should
run the model, and the shape you saved is as same as the input of its
running. `@paddle.jit.to_static` must be added on `forward` function of
your layer in dynamic mode now and these will be optimized later.
Args: Args:
save_dir (str): The directory path to save the inference model. save_dir (str): The directory path to save the inference model.
...@@ -1678,14 +1666,11 @@ class Model(object): ...@@ -1678,14 +1666,11 @@ class Model(object):
return result_list return result_list
# TODO:
# 1. Make it Unnecessary to run model before calling `save_inference_model` for users in dygraph.
# 2. Save correct shape of input, now the interface stores the shape that the user sent to
# the inputs of the model in running.
# 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode.
if fluid.in_dygraph_mode(): if fluid.in_dygraph_mode():
with fluid.framework._dygraph_guard(None): with fluid.framework._dygraph_guard(None):
layer = self.network layer = self.network
layer.forward = paddle.jit.to_static(
layer.forward, input_spec=self._inputs)
# 1. input check # 1. input check
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
...@@ -1879,18 +1864,7 @@ class Model(object): ...@@ -1879,18 +1864,7 @@ class Model(object):
def _verify_spec(self, specs, is_input=False): def _verify_spec(self, specs, is_input=False):
out_specs = [] out_specs = []
if specs is None: if isinstance(specs, dict):
# Note(Aurelius84): If not specific specs of `Input`, using argument names of `forward` function
# to generate `Input`. But how can we know the actual shape of each input tensor?
if is_input:
out_specs = [
Input(
name=n, shape=[None])
for n in extract_args(self.network.forward) if n != 'self'
]
else:
out_specs = to_list(specs)
elif isinstance(specs, dict):
assert is_input == False assert is_input == False
out_specs = [specs[n] \ out_specs = [specs[n] \
for n in extract_args(self.network.forward) if n != 'self'] for n in extract_args(self.network.forward) if n != 'self']
......
...@@ -67,35 +67,6 @@ class LeNetDygraph(paddle.nn.Layer): ...@@ -67,35 +67,6 @@ class LeNetDygraph(paddle.nn.Layer):
return x return x
class LeNetDeclarative(fluid.dygraph.Layer):
def __init__(self, num_classes=10, classifier_activation=None):
super(LeNetDeclarative, self).__init__()
self.num_classes = num_classes
self.features = Sequential(
Conv2d(
1, 6, 3, stride=1, padding=1),
ReLU(),
Pool2D(2, 'max', 2),
Conv2d(
6, 16, 5, stride=1, padding=0),
ReLU(),
Pool2D(2, 'max', 2))
if num_classes > 0:
self.fc = Sequential(
Linear(400, 120), Linear(120, 84), Linear(84, 10),
Softmax()) #Todo: accept any activation
@declarative
def forward(self, inputs):
x = self.features(inputs)
if self.num_classes > 0:
x = fluid.layers.flatten(x, 1)
x = self.fc(x)
return x
class MnistDataset(MNIST): class MnistDataset(MNIST):
def __init__(self, mode, return_label=True, sample_num=None): def __init__(self, mode, return_label=True, sample_num=None):
super(MnistDataset, self).__init__(mode=mode) super(MnistDataset, self).__init__(mode=mode)
...@@ -444,7 +415,9 @@ class TestModelFunction(unittest.TestCase): ...@@ -444,7 +415,9 @@ class TestModelFunction(unittest.TestCase):
# dynamic saving # dynamic saving
device = paddle.set_device('cpu') device = paddle.set_device('cpu')
fluid.enable_dygraph(device) fluid.enable_dygraph(device)
model = Model(MyModel(classifier_activation=None)) inputs = [InputSpec([None, 20], 'float32', 'x')]
labels = [InputSpec([None, 1], 'int64', 'label')]
model = Model(MyModel(classifier_activation=None), inputs, labels)
optim = fluid.optimizer.SGD(learning_rate=0.001, optim = fluid.optimizer.SGD(learning_rate=0.001,
parameter_list=model.parameters()) parameter_list=model.parameters())
model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum")) model.prepare(optimizer=optim, loss=CrossEntropyLoss(reduction="sum"))
...@@ -543,11 +516,10 @@ class TestModelFunction(unittest.TestCase): ...@@ -543,11 +516,10 @@ class TestModelFunction(unittest.TestCase):
def test_export_deploy_model(self): def test_export_deploy_model(self):
for dynamic in [True, False]: for dynamic in [True, False]:
fluid.enable_dygraph() if dynamic else None paddle.disable_static() if dynamic else None
# paddle.disable_static() if dynamic else None
prog_translator = ProgramTranslator() prog_translator = ProgramTranslator()
prog_translator.enable(False) if not dynamic else None prog_translator.enable(False) if not dynamic else None
net = LeNetDeclarative() net = LeNet()
inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')] inputs = [InputSpec([None, 1, 28, 28], 'float32', 'x')]
model = Model(net, inputs) model = Model(net, inputs)
model.prepare() model.prepare()
...@@ -556,8 +528,9 @@ class TestModelFunction(unittest.TestCase): ...@@ -556,8 +528,9 @@ class TestModelFunction(unittest.TestCase):
os.makedirs(save_dir) os.makedirs(save_dir)
tensor_img = np.array( tensor_img = np.array(
np.random.random((1, 1, 28, 28)), dtype=np.float32) np.random.random((1, 1, 28, 28)), dtype=np.float32)
ori_results = model.test_batch(tensor_img)
model.save(save_dir, training=False) model.save(save_dir, training=False)
ori_results = model.test_batch(tensor_img)
fluid.disable_dygraph() if dynamic else None fluid.disable_dygraph() if dynamic else None
place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda( place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda(
...@@ -574,6 +547,7 @@ class TestModelFunction(unittest.TestCase): ...@@ -574,6 +547,7 @@ class TestModelFunction(unittest.TestCase):
np.testing.assert_allclose( np.testing.assert_allclose(
results, ori_results, rtol=1e-5, atol=1e-7) results, ori_results, rtol=1e-5, atol=1e-7)
shutil.rmtree(save_dir) shutil.rmtree(save_dir)
paddle.enable_static()
class TestRaiseError(unittest.TestCase): class TestRaiseError(unittest.TestCase):
...@@ -585,6 +559,14 @@ class TestRaiseError(unittest.TestCase): ...@@ -585,6 +559,14 @@ class TestRaiseError(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model = Model(net, inputs, labels) model = Model(net, inputs, labels)
def test_input_without_input_spec(self):
for dynamic in [True, False]:
paddle.disable_static() if dynamic else None
net = MyModel(classifier_activation=None)
with self.assertRaises(TypeError):
model = Model(net)
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册