Unverified commit 761ed17d, authored by LiuChiachi, committed by GitHub

Update save inference model to support dygraph (#25894)

* update save_inference_model for hapi

* update save_inference_model to support dygraph

* fix comments

* fix comments

* test=develop

* test, test=develop

* fix dim test, test=develop

* test, test=develop

* add test_export_deploy_model_dynamic

* fix unittest for hapi: save_inference_model

* fix code style

* accept review by guoshengCS

* fix coverage rate

* update doc for save_inference_model and copyright

* change test model back to LeNet() in test_export_deploy_model

* copy jit.save, use LeNet() to test export deploy model

* add return value for dygraph, and fix doc error

* corrected the doc writing

* Delete redundant import and correct import order in sample code.

* remove 'fluid' and add prepare() and fit() in sample code

* correct usage of API 2.0 in sample code

* fix sample code bugs

* fix code style bugs

* fix test_model.py bugs

* set for_inference=True

* correct usage for static.InputSpec

* update doc for model.save

* correct usage of API 2.0

* rename param name for model.save

* correct for_inference as training
Parent: d32beea2
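At a glance, the workflow this commit enables is saving a dygraph model for inference through `Model.save(path, training=False)`. The following minimal sketch is adapted from the docstring example in the diff below; it assumes the `paddle.incubate.hapi` API of this development branch (the `SimpleNet` layer is hypothetical), and names may differ in later releases.

    import numpy as np
    import paddle
    import paddle.incubate.hapi as hapi

    class SimpleNet(paddle.nn.Layer):  # hypothetical layer for illustration
        def __init__(self):
            super(SimpleNet, self).__init__()
            self._fc = paddle.nn.Linear(784, 10)

        @paddle.jit.to_static  # currently required to export from dygraph
        def forward(self, x):
            return self._fc(x)

    device = hapi.set_device('cpu')
    paddle.disable_static(device)
    input = hapi.Input([None, 784], 'float32', 'x')
    model = hapi.Model(SimpleNet(), input)
    model.prepare()
    # Run the model once first: the exported input shape is the shape used here.
    ori_results = model.test_batch(np.random.random((1, 784)).astype('float32'))
    model.save('checkpoint/test')         # save for training: parameters (+ optimizer state)
    model.save('inference_model', False)  # save for inference: program + parameters only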
-# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -26,17 +26,22 @@ from collections import Iterable
 import paddle
 from paddle import fluid
+from paddle.fluid import core
+from paddle.fluid.framework import in_dygraph_mode, Variable, ParamBase, _current_expected_place
 # Note: Use alias `Input` temporarily before releasing hapi feature.
 from paddle.static import InputSpec as Input
-from paddle.fluid.framework import in_dygraph_mode, Variable
 from paddle.fluid.executor import global_scope
 from paddle.fluid.io import is_belong_to_optimizer
 from paddle.fluid.dygraph.base import to_variable
 from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator, FunctionSpec
 from paddle.fluid.layers.utils import flatten
 from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
 from paddle.fluid.incubate.fleet.base import role_maker
+from paddle.fluid.executor import scope_guard, Executor
 from paddle.io import DataLoader, Dataset
+from paddle.fluid.dygraph.layers import Layer
 from paddle.metric import Metric
 from .distributed import DistributedBatchSampler, _all_gather, prepare_distributed_context, _parallel_context_initialized
@@ -846,24 +851,32 @@ class Model(object):
         """
         return self._adapter.test_batch(inputs)

-    def save(self, path):
+    def save(self, path, training=True):
         """
-        This function saves parameters, optimizer infomation to path.
+        This function saves parameters, optimizer information, or the model
+        and parameters for inference only, to path, depending on the
+        parameter `training`.

-        The parameters contains all the trainable Variable, will save to
-        a file with suffix ".pdparams".
+        If `training` is set to True, the saved parameters contain all the
+        trainable Variables and are saved to a file with suffix ".pdparams".

         The optimizer information contains all the variables used by the
         optimizer. For the Adam optimizer, this contains beta1, beta2,
         momentum, etc. All the information is saved to a file with suffix
         ".pdopt". (If the optimizer has no variables that need saving, like
         SGD, the file will not be generated.)

-        This function will silently overwrite existing file
-        at the target location.
+        This function will silently overwrite an existing file at the
+        target location.
+
+        If `training` is set to False, only the inference model is saved.
+        It should be noted that before using `save`, you should run the
+        model; the input shape saved is the same as the shape of the input
+        the model was run with. In dynamic mode, `@paddle.jit.to_static`
+        currently must be added to the `forward` function of your layer;
+        this will be optimized later.

         Args:
             path (str): The file prefix to save the model. The format is
                 'dirname/file_prefix' or 'file_prefix'. If an empty string
                 is given, an exception will be raised.
+            training (bool, optional): Whether to save for training. If
+                not, save for inference only. Default: True.

         Returns:
             None
@@ -871,25 +884,47 @@ class Model(object):
         Examples:

             .. code-block:: python

                 import paddle
                 import paddle.incubate.hapi as hapi
+                from paddle.nn import Linear
+                from paddle.incubate.hapi.datasets.mnist import MNIST as MnistDataset

-                class MyNet(paddle.nn.Layer):
-                    def __init__(self):
-                        super(MyNet, self).__init__()
-                        self._fc = paddle.nn.Linear(784, 1, act='softmax')
+                class Mnist(paddle.nn.Layer):
+                    def __init__(self):
+                        super(Mnist, self).__init__()
+                        self._fc = Linear(784, 10, act='softmax')
+
+                    @paddle.jit.to_static  # needed to save for inference in dygraph
                     def forward(self, x):
                         y = self._fc(x)
                         return y

-                device = hapi.set_device('cpu')
-                paddle.disable_static(device)
-                model = hapi.Model(MyNet())
-                model.save('checkpoint/test')
+                dynamic = True  # False
+                device = hapi.set_device('cpu')
+                # if using the static graph, do not set
+                paddle.disable_static(device) if dynamic else None
+
+                # inputs and labels are not required for dynamic graph.
+                input = hapi.Input([None, 784], 'float32', 'x')
+                label = hapi.Input([None, 1], 'int64', 'label')
+                model = hapi.Model(Mnist(), input, label)
+                optim = paddle.optimizer.SGD(learning_rate=1e-3,
+                    parameter_list=model.parameters())
+                model.prepare(optim,
+                              paddle.nn.CrossEntropyLoss(),
+                              hapi.metrics.Accuracy())
+                mnist_data = hapi.datasets.MNIST(mode='train', chw_format=False)
+                model.fit(mnist_data, epochs=1, batch_size=32, verbose=0)
+                model.save('checkpoint/test')         # save for training
+                model.save('inference_model', False)  # save for inference
         """
         if ParallelEnv().local_rank == 0:
-            self._adapter.save(path)
+            if not training:
+                self._save_inference_model(path)
+            else:
+                self._adapter.save(path)
     def load(self, path, skip_mismatch=False, reset_optimizer=False):
         """
@@ -1474,13 +1509,17 @@ class Model(object):
             cbks.on_end('test', logs)
         return outputs
-    def save_inference_model(self,
-                             save_dir,
-                             model_filename=None,
-                             params_filename=None,
-                             model_only=False):
+    def _save_inference_model(self,
+                              save_dir,
+                              model_filename=None,
+                              params_filename=None,
+                              model_only=False):
         """
-        Save inference model must in static mode.
+        Save inference model can be in static or dynamic mode.
+        It should be noted that before using `save_inference_model`, you
+        should run the model; the input shape saved is the same as the
+        shape of the input the model was run with. In dynamic mode,
+        `@paddle.jit.to_static` currently must be added to the `forward`
+        function of your layer; this will be optimized later.

         Args:
             save_dir (str): The directory path to save the inference model.
@@ -1496,39 +1535,145 @@ class Model(object):
         Returns:
             list: The fetch variables' name list

         Examples:

             .. code-block:: python

+                import numpy as np
+                import paddle
+                from paddle.static import InputSpec
                 import paddle.incubate.hapi as hapi
+                from paddle.nn import Layer, Linear
+                from paddle.incubate.hapi.datasets.mnist import MNIST as MnistDataset

-                input = hapi.Input([-1, 1, 28, 28], 'float32', 'image')
-                model = hapi.Model(hapi.vision.LeNet(), input)
-                model.prepare()
+                class Mnist(Layer):
+                    def __init__(self, classifier_act=None):
+                        super(Mnist, self).__init__()
+                        self.fc = Linear(input_dim=784, output_dim=10, act="softmax")
+
+                    @paddle.jit.to_static  # In static mode, you need to delete this.
+                    def forward(self, inputs):
+                        outputs = self.fc(inputs)
+                        return outputs
+
+                dynamic = True  # False
+                device = hapi.set_device('gpu')
+                # if using the static graph, do not set
+                paddle.disable_static(device) if dynamic else None
+
+                # inputs and labels are not required for dynamic graph.
+                input = InputSpec([None, 784], 'float32', 'x')
+                label = InputSpec([None, 1], 'int64', 'label')
+                model = hapi.Model(Mnist(), input, label)
+                optim = paddle.optimizer.SGD(learning_rate=1e-3,
+                    parameter_list=model.parameters())
+                model.prepare(optim,
+                              paddle.nn.CrossEntropyLoss(),
+                              hapi.metrics.Accuracy())
+                mnist_data = hapi.datasets.MNIST(mode='train', chw_format=False)
+                model.fit(mnist_data, epochs=1, batch_size=32, verbose=0)
                 model.save_inference_model('inference_model')
         """
-        assert not fluid.in_dygraph_mode(
-        ), 'Save inference model must in static mode!'
-
-        prog = self._adapter._progs.get('test', None)
-        assert prog, \
-            "Model is not ready, please call `model.prepare()` first"
-
-        infer_prog = prog.clone(for_test=True)
-
-        input_names = [v.name for v in self._adapter._input_vars['test']]
-        endpoints = self._adapter._endpoints['test']['output']
-
-        return fluid.io.save_inference_model(
-            save_dir,
-            input_names,
-            endpoints,
-            self._adapter._executor,
-            main_program=infer_prog,
-            model_filename=model_filename,
-            params_filename=params_filename,
-            program_only=model_only)
+        def get_inout_spec(all_vars, return_name=False):
+            result_list = []
+            valid_vars = [var for var in all_vars if isinstance(var, Variable)]
+            result_list = valid_vars
+            if return_name:
+                result_list = [var.name for var in result_list]
+            return result_list
+
+        # TODO:
+        # 1. Make it unnecessary to run the model before calling
+        #    `save_inference_model` for users in dygraph.
+        # 2. Save the correct shape of the input; currently the interface
+        #    stores the shape that the user fed to the model when running it.
+        # 3. Make it unnecessary to add `@paddle.jit.to_static` for users
+        #    in dynamic mode.
+        if fluid.in_dygraph_mode():
+            layer = self.network
+            fluid.disable_dygraph()
+
+            # 1. input check
+            prog_translator = ProgramTranslator()
+            if not prog_translator.enable_declarative:
+                raise RuntimeError(
+                    "save_inference_model doesn't work when setting ProgramTranslator.enable=False."
+                )
+            if not isinstance(layer, Layer):
+                raise TypeError(
+                    "The input layer should be 'Layer', but received layer type is %s."
+                    % type(layer))
+
+            # 2. get program of declarative Layer.forward
+            prog_cache = prog_translator.get_program_cache()
+            # make dummy args & kwargs, to get expected FunctionSpec
+            layer_func = FunctionSpec(type(layer).forward, [layer], {})
+            concrete_program, _ = prog_cache.get_program(layer_func)
+
+            # NOTE: we maintain the mapping of variable name to structured
+            # name; the buffer variables (non-persistable) saved to the
+            # inference program may not be needed by the dygraph Layer, so
+            # we only record the structured names of the state_dict
+            # variables.
+            state_names_dict = dict()
+            for structured_name, var in layer.state_dict().items():
+                state_names_dict[var.name] = structured_name
+
+            # 3. share parameters from Layer to scope & record var info
+            scope = core.Scope()
+            extra_var_info = dict()
+            for param_or_buffer in concrete_program.parameters:
+                # share to scope
+                param_or_buffer_tensor = scope.var(
+                    param_or_buffer.name).get_tensor()
+                src_tensor = param_or_buffer.value().get_tensor()
+                param_or_buffer_tensor._share_data_with(src_tensor)
+                # record var info
+                extra_info_dict = dict()
+                if param_or_buffer.name in state_names_dict:
+                    extra_info_dict['structured_name'] = state_names_dict[
+                        param_or_buffer.name]
+                extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
+                if isinstance(param_or_buffer, ParamBase):
+                    extra_info_dict['trainable'] = param_or_buffer.trainable
+                extra_var_info[param_or_buffer.name] = extra_info_dict
+
+            # 4. build input & output spec
+            input_var_names = get_inout_spec(concrete_program.inputs, True)
+            output_vars = get_inout_spec(concrete_program.outputs)
+
+            # 5. save inference model
+            with scope_guard(scope):
+                return fluid.io.save_inference_model(
+                    dirname=save_dir,
+                    feeded_var_names=input_var_names,
+                    target_vars=output_vars,
+                    executor=Executor(_current_expected_place()),
+                    main_program=concrete_program.main_program.clone(),
+                    model_filename=model_filename,
+                    params_filename=params_filename,
+                    program_only=model_only)
+        else:
+            prog = self._adapter._progs.get('test', None)
+            assert prog, \
+                "Model is not ready, please call `model.prepare()` first"
+
+            infer_prog = prog.clone(for_test=True)
+
+            input_names = [v.name for v in self._adapter._input_vars['test']]
+            endpoints = self._adapter._endpoints['test']['output']
+
+            return fluid.io.save_inference_model(
+                save_dir,
+                input_names,
+                endpoints,
+                self._adapter._executor,
+                main_program=infer_prog,
+                model_filename=model_filename,
+                params_filename=params_filename,
+                program_only=model_only)
     def _run_one_epoch(self, data_loader, callbacks, mode, logs={}):
         outputs = []
...
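Once exported, the artifacts under the save directory can be loaded back through the standard static-graph inference path; this is exactly what the updated unit test below does. A minimal consumer-side sketch, assuming a model was exported to 'inference_model' as in the docstring example above:

    import numpy as np
    from paddle import fluid

    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    # Load into a fresh scope so the restored parameters cannot collide
    # with variables already living in the global scope.
    with fluid.scope_guard(fluid.Scope()):
        [program, feed_names, fetch_targets] = fluid.io.load_inference_model(
            dirname='inference_model', executor=exe)
        img = np.random.random((1, 784)).astype('float32')
        results = exe.run(program,
                          feed={feed_names[0]: img},
                          fetch_list=fetch_targets)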
@@ -33,6 +33,8 @@ from paddle.metric import Accuracy
 from paddle.incubate.hapi.datasets import MNIST
 from paddle.incubate.hapi.vision.models import LeNet
 from paddle.incubate.hapi.distributed import DistributedBatchSampler, prepare_distributed_context
+from paddle.fluid.dygraph.jit import declarative
+from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator


 class LeNetDygraph(fluid.dygraph.Layer):
@@ -65,6 +67,37 @@ class LeNetDygraph(fluid.dygraph.Layer):
         return x
+class LeNetDeclarative(fluid.dygraph.Layer):
+    def __init__(self, num_classes=10, classifier_activation=None):
+        super(LeNetDeclarative, self).__init__()
+        self.num_classes = num_classes
+        self.features = Sequential(
+            Conv2d(
+                1, 6, 3, stride=1, padding=1),
+            ReLU(),
+            Pool2D(2, 'max', 2),
+            Conv2d(
+                6, 16, 5, stride=1, padding=0),
+            ReLU(),
+            Pool2D(2, 'max', 2))
+
+        if num_classes > 0:
+            self.fc = Sequential(
+                Linear(400, 120),
+                Linear(120, 84),
+                Linear(
+                    84, 10, act=classifier_activation))
+
+    @declarative
+    def forward(self, inputs):
+        x = self.features(inputs)
+        if self.num_classes > 0:
+            x = fluid.layers.flatten(x, 1)
+            x = self.fc(x)
+        return x
 class MnistDataset(MNIST):
     def __init__(self, mode, return_label=True, sample_num=None):
         super(MnistDataset, self).__init__(mode=mode)
@@ -335,7 +368,6 @@ class TestModelFunction(unittest.TestCase):
             model = Model(net, inputs, labels)
             model.prepare(optim2, loss=CrossEntropyLoss(reduction="sum"))
             loss, = model.train_batch([data], [label])
-
             np.testing.assert_allclose(loss.flatten(), ref.flatten())
             fluid.disable_dygraph() if dynamic else None
@@ -445,33 +477,38 @@ class TestModelFunction(unittest.TestCase):
         fluid.disable_dygraph() if dynamic else None
     def test_export_deploy_model(self):
-        net = LeNet()
-        inputs = [Input([-1, 1, 28, 28], 'float32', 'image')]
-        model = Model(net, inputs)
-        model.prepare()
-        save_dir = tempfile.mkdtemp()
-        if not os.path.exists(save_dir):
-            os.makedirs(save_dir)
-
-        tensor_img = np.array(
-            np.random.random((1, 1, 28, 28)), dtype=np.float32)
-        ori_results = model.test_batch(tensor_img)
-
-        model.save_inference_model(save_dir)
-
-        place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda(
-        ) else fluid.CUDAPlace(0)
-        exe = fluid.Executor(place)
-        [inference_program, feed_target_names, fetch_targets] = (
-            fluid.io.load_inference_model(
-                dirname=save_dir, executor=exe))
-        results = exe.run(inference_program,
-                          feed={feed_target_names[0]: tensor_img},
-                          fetch_list=fetch_targets)
-
-        np.testing.assert_allclose(results, ori_results, rtol=1e-6)
-        shutil.rmtree(save_dir)
+        for dynamic in [True, False]:
+            fluid.enable_dygraph() if dynamic else None
+            # paddle.disable_static() if dynamic else None
+            prog_translator = ProgramTranslator()
+            prog_translator.enable(False) if not dynamic else None
+            net = LeNetDeclarative()
+            inputs = [Input([None, 1, 28, 28], 'float32', 'x')]
+            model = Model(net, inputs)
+            model.prepare()
+            save_dir = tempfile.mkdtemp()
+            if not os.path.exists(save_dir):
+                os.makedirs(save_dir)
+            tensor_img = np.array(
+                np.random.random((1, 1, 28, 28)), dtype=np.float32)
+            ori_results = model.test_batch(tensor_img)
+            model.save(save_dir, training=False)
+            fluid.disable_dygraph() if dynamic else None
+
+            place = fluid.CPUPlace() if not fluid.is_compiled_with_cuda(
+            ) else fluid.CUDAPlace(0)
+            new_scope = fluid.Scope()
+            with fluid.scope_guard(new_scope):
+                exe = fluid.Executor(place)
+                [inference_program, feed_target_names, fetch_targets] = (
+                    fluid.io.load_inference_model(
+                        dirname=save_dir, executor=exe))
+                results = exe.run(inference_program,
+                                  feed={feed_target_names[0]: tensor_img},
+                                  fetch_list=fetch_targets)
+                np.testing.assert_allclose(
+                    results, ori_results, rtol=1e-5, atol=1e-7)
+                shutil.rmtree(save_dir)


 class TestRaiseError(unittest.TestCase):
...
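A note on the dynamic/static toggle in the test above: `_save_inference_model` raises a RuntimeError if called in dygraph while `ProgramTranslator.enable(False)` is set, so the test disables translation only for the static-graph iteration, where the `@declarative` decorator stays inert and the `else` branch of `_save_inference_model` is exercised. An illustrative sketch of the toggle (not part of the diff):

    from paddle.fluid.dygraph.dygraph_to_static.program_translator import ProgramTranslator

    translator = ProgramTranslator()
    translator.enable(False)  # skip dygraph-to-static translation;
                              # the static-graph export branch is taken
    # ... build, run, and save the model as in the static branch ...
    translator.enable(True)   # restore translation for dygraph export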