未验证 提交 25dad426 编写于 作者: L LiuChiachi 提交者: GitHub

fix sample code for hapi.model.save (#26667)

* fix sample code for hapi.model.save, test=document_fix

* test=document_fix

* update usage of 2.0 API, test=document_fix

* fix bugs, return dygraph back to users while using model.save in dygraph

* fix code style
上级 7afb1df1
......@@ -891,33 +891,31 @@ class Model(object):
class Mnist(paddle.nn.Layer):
def __init__(self):
super(MyNet, self).__init__()
self._fc = Linear(784, 1, act='softmax')
super(Mnist, self).__init__()
self._fc = Linear(784, 10, act='softmax')
@paddle.jit.to_static # If save for inference in dygraph, need this
def forward(self, x):
y = self._fc(x)
return y
# If save for inference in dygraph, need this
@paddle.jit.to_static
def forward(self, x):
y = self._fc(x)
return y
dynamic = True # False
dynamic = True # False
device = hapi.set_device('cpu')
# if use static graph, do not set
paddle.disable_static(device) if dynamic else None
# inputs and labels are not required for dynamic graph.
input = hapi.Input([None, 784], 'float32', 'x')
label = hapi.Input([None, 1], 'int64', 'label')
model = hapi.Model(Mnist(), input, label)
optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
model.prepare(optim,
paddle.nn.CrossEntropyLoss(),
hapi.metrics.Accuracy())
parameter_list=model.parameters())
model.prepare(optim, paddle.nn.CrossEntropyLoss())
mnist_data = hapi.datasets.MNIST(mode='train', chw_format=False)
model.fit(mnist_data, epochs=1, batch_size=32, verbose=0)
model.save('checkpoint/test') # save for training
model.save('inference_model', False) # save for inference
model.save('checkpoint/test') # save for training
model.save('inference_model', False) # save for inference
"""
if ParallelEnv().local_rank == 0:
......@@ -1534,47 +1532,6 @@ class Model(object):
Returns:
list: The fetch variables' name list
Examples:
.. code-block:: python
import numpy as np
import paddle
from paddle.static import InputSpec
import paddle.incubate.hapi as hapi
from paddle.nn import Linear
from paddle.incubate.hapi.datasets.mnist import MNIST as MnistDataset
class Mnist(Layer):
def __init__(self, classifier_act=None):
super(Mnist, self).__init__()
self.fc = Linear(input_dim=784, output_dim=10, act="softmax")
@paddle.jit.to_static # In static mode, you need to delete this.
def forward(self, inputs):
outputs = self.fc(inputs)
return outputs
dynamic = True # False
device = hapi.set_device('gpu')
# if use static graph, do not set
paddle.disable_static(device) if dynamic else None
# inputs and labels are not required for dynamic graph.
input = InputSpec([None, 784], 'float32', 'x')
label = InputSpec([None, 1], 'int64', 'label')
model = hapi.Model(Mnist(), input, label)
optim = paddle.optimizer.SGD(learning_rate=1e-3,
parameter_list=model.parameters())
model.prepare(optim,
paddle.nn.CrossEntropyLoss(),
hapi.metrics.Accuracy())
mnist_data = hapi.datasets.MNIST(mode='train', chw_format=False)
model.fit(mnist_data, epochs=1, batch_size=32, verbose=0)
model.save_inference_model('inference_model')
"""
def get_inout_spec(all_vars, return_name=False):
......@@ -1592,65 +1549,66 @@ class Model(object):
# the inputs of the model in running.
# 3. Make it Unnecessary to add `@paddle.jit.to_static` for users in dynamic mode.
if fluid.in_dygraph_mode():
layer = self.network
fluid.disable_dygraph()
# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable_declarative:
raise RuntimeError(
"save_inference_model doesn't work when setting ProgramTranslator.enable=False."
)
if not isinstance(layer, Layer):
raise TypeError(
"The input layer should be 'Layer', but received layer type is %s."
% type(layer))
# 2. get program of declarative Layer.forward
concrete_program = layer.forward.concrete_program
# NOTE: we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
state_names_dict = dict()
for structured_name, var in layer.state_dict().items():
state_names_dict[var.name] = structured_name
# 3. share parameters from Layer to scope & record var info
scope = core.Scope()
extra_var_info = dict()
for param_or_buffer in concrete_program.parameters:
# share to scope
param_or_buffer_tensor = scope.var(
param_or_buffer.name).get_tensor()
src_tensor = param_or_buffer.value().get_tensor()
param_or_buffer_tensor._share_data_with(src_tensor)
# record var info
extra_info_dict = dict()
if param_or_buffer.name in state_names_dict:
extra_info_dict['structured_name'] = state_names_dict[
param_or_buffer.name]
extra_info_dict['stop_gradient'] = param_or_buffer.stop_gradient
if isinstance(param_or_buffer, ParamBase):
extra_info_dict['trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict
# 4. build input & output spec
input_var_names = get_inout_spec(concrete_program.inputs, True)
output_vars = get_inout_spec(concrete_program.outputs)
# 5. save inference model
with scope_guard(scope):
return fluid.io.save_inference_model(
dirname=save_dir,
feeded_var_names=input_var_names,
target_vars=output_vars,
executor=Executor(_current_expected_place()),
main_program=concrete_program.main_program.clone(),
model_filename=model_filename,
params_filename=params_filename,
program_only=model_only)
with fluid.framework._dygraph_guard(None):
layer = self.network
# 1. input check
prog_translator = ProgramTranslator()
if not prog_translator.enable_declarative:
raise RuntimeError(
"save_inference_model doesn't work when setting ProgramTranslator.enable=False."
)
if not isinstance(layer, Layer):
raise TypeError(
"The input layer should be 'Layer', but received layer type is %s."
% type(layer))
# 2. get program of declarative Layer.forward
concrete_program = layer.forward.concrete_program
# NOTE: we maintain the mapping of variable name to
# structured name, the buffer variable (non-persistable)
# saved to inference program may not need by dygraph Layer,
# we only record the state_dict variable's structured name
state_names_dict = dict()
for structured_name, var in layer.state_dict().items():
state_names_dict[var.name] = structured_name
# 3. share parameters from Layer to scope & record var info
scope = core.Scope()
extra_var_info = dict()
for param_or_buffer in concrete_program.parameters:
# share to scope
param_or_buffer_tensor = scope.var(
param_or_buffer.name).get_tensor()
src_tensor = param_or_buffer.value().get_tensor()
param_or_buffer_tensor._share_data_with(src_tensor)
# record var info
extra_info_dict = dict()
if param_or_buffer.name in state_names_dict:
extra_info_dict['structured_name'] = state_names_dict[
param_or_buffer.name]
extra_info_dict[
'stop_gradient'] = param_or_buffer.stop_gradient
if isinstance(param_or_buffer, ParamBase):
extra_info_dict['trainable'] = param_or_buffer.trainable
extra_var_info[param_or_buffer.name] = extra_info_dict
# 4. build input & output spec
input_var_names = get_inout_spec(concrete_program.inputs, True)
output_vars = get_inout_spec(concrete_program.outputs)
# 5. save inference model
with scope_guard(scope):
return fluid.io.save_inference_model(
dirname=save_dir,
feeded_var_names=input_var_names,
target_vars=output_vars,
executor=Executor(_current_expected_place()),
main_program=concrete_program.main_program.clone(),
model_filename=model_filename,
params_filename=params_filename,
program_only=model_only)
else:
prog = self._adapter._progs.get('test', None)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册