Unverified commit 9b49f024, authored by Chen Weihang, committed by GitHub

Polish jit.save/load design & remove paddle.SaveLoadConfig (#27623)

* replace config by kwargs

* change save path from dir to prefix

* fix failed unittests

* revert unittest name change

* polish en docs

* add more tests for coverage
Parent 74d3a550
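In summary, `paddle.jit.save` now takes a single `path` prefix rather than a `model_path` directory, and options that previously lived on `paddle.SaveLoadConfig` (such as `output_spec`) are passed as keyword arguments. A minimal before/after sketch; the `net` layer, `out` variable, and paths here are placeholders, not part of the commit:

    # before this commit (SaveLoadConfig object, directory path)
    config = paddle.SaveLoadConfig()
    config.output_spec = [out]
    paddle.jit.save(layer=net, model_path="./saved_infer_model", configs=config)

    # after this commit (keyword arguments, path prefix)
    paddle.jit.save(
        layer=net,
        path="./inference/net",  # prefix; produces net.pdmodel / net.pdiparams
        input_spec=[paddle.static.InputSpec(shape=[None, 784], dtype='float32')],
        output_spec=[out])
    net_loaded = paddle.jit.load("./inference/net")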
@@ -235,7 +235,6 @@ from .framework import grad #DEFINE_ALIAS
 from .framework import no_grad #DEFINE_ALIAS
 from .framework import save #DEFINE_ALIAS
 from .framework import load #DEFINE_ALIAS
-from .framework import SaveLoadConfig #DEFINE_ALIAS
 from .framework import DataParallel #DEFINE_ALIAS
 from .framework import NoamDecay #DEFINE_ALIAS
......
@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.nn import Conv2D
 from paddle.fluid.dygraph.nn import Pool2D
 from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.log_helper import get_logger
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 paddle.enable_static()
@@ -231,10 +232,11 @@ class TestImperativeQat(unittest.TestCase):
             before_save = lenet(test_img)

         # save inference quantized model
-        path = "./mnist_infer_model"
+        path = "./qat_infer_model/lenet"
+        save_dir = "./qat_infer_model"
         paddle.jit.save(
             layer=lenet,
-            model_path=path,
+            path=path,
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[None, 1, 28, 28], dtype='float32')
@@ -245,12 +247,12 @@ class TestImperativeQat(unittest.TestCase):
         else:
             place = core.CPUPlace()
         exe = fluid.Executor(place)
-        [inference_program, feed_target_names, fetch_targets] = (
-            fluid.io.load_inference_model(
-                dirname=path,
-                executor=exe,
-                model_filename="__model__",
-                params_filename="__variables__"))
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(
+             dirname=save_dir,
+             executor=exe,
+             model_filename="lenet" + INFER_MODEL_SUFFIX,
+             params_filename="lenet" + INFER_PARAMS_SUFFIX)
         after_save, = exe.run(inference_program,
                               feed={feed_target_names[0]: test_data},
                               fetch_list=fetch_targets)
@@ -339,7 +341,7 @@ class TestImperativeQat(unittest.TestCase):
         paddle.jit.save(
             layer=lenet,
-            model_path="./dynamic_mnist",
+            path="./dynamic_mnist/model",
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[None, 1, 28, 28], dtype='float32')
......
@@ -31,6 +31,7 @@ from paddle.fluid.dygraph.nn import Conv2D
 from paddle.fluid.dygraph.nn import Pool2D
 from paddle.fluid.dygraph.nn import Linear
 from paddle.fluid.log_helper import get_logger
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 paddle.enable_static()
@@ -231,10 +232,11 @@ class TestImperativeQat(unittest.TestCase):
             before_save = lenet(test_img)

         # save inference quantized model
-        path = "./mnist_infer_model"
+        path = "./qat_infer_model/mnist"
+        save_dir = "./qat_infer_model"
         paddle.jit.save(
             layer=lenet,
-            model_path=path,
+            path=path,
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[None, 1, 28, 28], dtype='float32')
@@ -245,12 +247,12 @@ class TestImperativeQat(unittest.TestCase):
         else:
             place = core.CPUPlace()
         exe = fluid.Executor(place)
-        [inference_program, feed_target_names, fetch_targets] = (
-            fluid.io.load_inference_model(
-                dirname=path,
-                executor=exe,
-                model_filename="__model__",
-                params_filename="__variables__"))
+        [inference_program, feed_target_names,
+         fetch_targets] = fluid.io.load_inference_model(
+             dirname=save_dir,
+             executor=exe,
+             model_filename="mnist" + INFER_MODEL_SUFFIX,
+             params_filename="mnist" + INFER_PARAMS_SUFFIX)
         after_save, = exe.run(inference_program,
                               feed={feed_target_names[0]: test_data},
                               fetch_list=fetch_targets)
@@ -339,7 +341,7 @@ class TestImperativeQat(unittest.TestCase):
         paddle.jit.save(
             layer=lenet,
-            model_path="./dynamic_mnist",
+            path="./dynamic_mnist/model",
             input_spec=[
                 paddle.static.InputSpec(
                     shape=[None, 1, 28, 28], dtype='float32')
......
@@ -24,8 +24,8 @@ from . import learning_rate_scheduler
 import warnings
 from .. import core
 from .base import guard
-from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
-from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME
+from paddle.fluid.dygraph.jit import _SaveLoadConfig
+from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

 __all__ = [
     'save_dygraph',
@@ -33,35 +33,23 @@ __all__ = [
 ]

-# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
-# ensure compatibility when user still use keep_name_table argument
-def deprecate_keep_name_table(func):
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        def __warn_and_build_configs__(keep_name_table):
-            warnings.warn(
-                "The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
-                DeprecationWarning)
-            config = SaveLoadConfig()
-            config.keep_name_table = keep_name_table
-            return config
-
-        # deal with arg `keep_name_table`
-        if len(args) > 1 and isinstance(args[1], bool):
-            args = list(args)
-            args[1] = __warn_and_build_configs__(args[1])
-        # deal with kwargs
-        elif 'keep_name_table' in kwargs:
-            kwargs['config'] = __warn_and_build_configs__(kwargs[
-                'keep_name_table'])
-            kwargs.pop('keep_name_table')
-        else:
-            # do nothing
-            pass
-
-        return func(*args, **kwargs)
-
-    return wrapper
+def _parse_load_config(configs):
+    supported_configs = ['model_filename', 'params_filename', 'keep_name_table']
+
+    # input check
+    for key in configs:
+        if key not in supported_configs:
+            raise ValueError(
+                "The additional config (%s) of `paddle.fluid.load_dygraph` is not supported."
+                % (key))
+
+    # construct inner config
+    inner_config = _SaveLoadConfig()
+    inner_config.model_filename = configs.get('model_filename', None)
+    inner_config.params_filename = configs.get('params_filename', None)
+    inner_config.keep_name_table = configs.get('keep_name_table', None)
+
+    return inner_config


 @dygraph_only
@@ -132,12 +120,12 @@ def save_dygraph(state_dict, model_path):
         pickle.dump(model_dict, f, protocol=2)


+# NOTE(chenweihang): load_dygraph will deprecated in future, we don't
+# support new loading features for it
 # TODO(qingqing01): remove dygraph_only to support loading static model.
 # maybe need to unify the loading interface after 2.0 API is ready.
 # @dygraph_only
-@deprecate_save_load_configs
-@deprecate_keep_name_table
-def load_dygraph(model_path, config=None):
+def load_dygraph(model_path, **configs):
     '''
     :api_attr: imperative
@@ -152,10 +140,13 @@ def load_dygraph(model_path, config=None):
     Args:
         model_path(str) : The file prefix store the state_dict.
                           (The path should Not contain suffix '.pdparams')
-        config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
-            object that specifies additional configuration options, these options
-            are for compatibility with ``jit.save/io.save_inference_model`` formats.
-            Default None.
+        **configs (dict, optional): other save configuration options for compatibility. We do not
+            recommend using these configurations, if not necessary, DO NOT use them. Default None.
+            The following options are currently supported:
+            (1) model_filename (string): The inference model file name of the paddle 1.x ``save_inference_model``
+            save format. Default file name is :code:`__model__` .
+            (2) params_filename (string): The persistable variables file name of the paddle 1.x ``save_inference_model``
+            save format. No default file name, save variables separately by default.

     Returns:
         state_dict(dict) : the dict store the state_dict
@@ -196,8 +187,7 @@ def load_dygraph(model_path, config=None):
     opti_file_path = model_prefix + ".pdopt"

     # deal with argument `config`
-    if config is None:
-        config = SaveLoadConfig()
+    config = _parse_load_config(configs)

     if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
         # Load state dict by `save_dygraph` save format
@@ -246,7 +236,6 @@ def load_dygraph(model_path, config=None):
                 persistable_var_dict = _construct_params_and_buffers(
                     model_prefix,
                     programs,
-                    config.separate_params,
                     config.params_filename,
                     append_suffix=False)

@@ -255,9 +244,9 @@ def load_dygraph(model_path, config=None):
                 for var_name in persistable_var_dict:
                     para_dict[var_name] = persistable_var_dict[var_name].numpy()

-                # if __variables.info__ exists, we can recover structured_name
-                var_info_path = os.path.join(model_prefix,
-                                             EXTRA_VAR_INFO_FILENAME)
+                # if *.info exists, we can recover structured_name
+                var_info_filename = str(config.params_filename) + ".info"
+                var_info_path = os.path.join(model_prefix, var_info_filename)
                 if os.path.exists(var_info_path):
                     with open(var_info_path, 'rb') as f:
                         extra_var_info = pickle.load(f)
......
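With this change, callers of `load_dygraph` pass the compatibility options directly as keyword arguments and `_parse_load_config` validates them; an unsupported key raises a ValueError. A short sketch of the new call shape; the model path and file names are placeholders:

    import paddle.fluid as fluid

    # state dict saved in the paddle 1.x save_inference_model format,
    # with a combined parameters file
    para_dict, opti_dict = fluid.load_dygraph(
        "./my_inference_model",
        model_filename="__model__",
        params_filename="__params__")

    # an unsupported key is rejected by _parse_load_config:
    # fluid.load_dygraph("./my_inference_model", foo=1)  # raises ValueError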
@@ -31,8 +31,10 @@ from paddle.fluid.dygraph.base import switch_to_static_graph

 __all__ = ['TranslatedLayer']

-VARIABLE_FILENAME = "__variables__"
-EXTRA_VAR_INFO_FILENAME = "__variables.info__"
+INFER_MODEL_SUFFIX = ".pdmodel"
+INFER_PARAMS_SUFFIX = ".pdiparams"
+INFER_PARAMS_INFO_SUFFIX = ".pdiparams.info"

 LOADED_VAR_SUFFIX = "load"
 PARAMETER_NAME_PREFIX = "param"
 BUFFER_NAME_PREFIX = "buffer"
@@ -424,11 +426,8 @@ def _load_persistable_vars_by_program(model_path,
     return load_var_dict


-def _load_persistable_vars(model_path,
-                           var_info_path,
-                           program_holder,
-                           separate_params=False,
-                           params_filename=None):
+def _load_persistable_vars(model_path, var_info_path, program_holder,
+                           params_filename):
     # 1. load extra var info
     with open(var_info_path, 'rb') as f:
         extra_var_info = pickle.load(f)
@@ -464,33 +463,22 @@ def _load_persistable_vars(model_path,
             new_var = framework._varbase_creator(
                 name=new_name, persistable=True)

-            # load separate vars
-            if separate_params is True:
-                framework._dygraph_tracer().trace_op(
-                    type='load',
-                    inputs={},
-                    outputs={'Out': new_var},
-                    attrs={'file_path': os.path.join(model_path, name)})
-
         new_var.stop_gradient = extra_var_info[name]['stop_gradient']
         load_var_dict[new_name] = new_var
         load_var_list.append(new_var)

     # 3. load all vars
-    if separate_params is False:
-        if params_filename is not None:
-            var_file_path = os.path.join(model_path, params_filename)
-        else:
-            var_file_path = os.path.join(model_path, VARIABLE_FILENAME)
-        if not os.path.exists(var_file_path):
-            if len(extra_var_info) != 0:
-                raise ValueError("The model to be loaded is incomplete.")
-        else:
-            framework._dygraph_tracer().trace_op(
-                type='load_combine',
-                inputs={},
-                outputs={'Out': load_var_list},
-                attrs={'file_path': var_file_path})
+    assert params_filename is not None, "params_filename should not be None."
+    var_file_path = os.path.join(model_path, params_filename)
+    if not os.path.exists(var_file_path):
+        if len(extra_var_info) != 0:
+            raise ValueError("The model to be loaded is incomplete.")
+    else:
+        framework._dygraph_tracer().trace_op(
+            type='load_combine',
+            inputs={},
+            outputs={'Out': load_var_list},
+            attrs={'file_path': var_file_path})

     return load_var_dict
@@ -532,14 +520,13 @@ def _construct_program_holders(model_path, model_filename=None):

 def _construct_params_and_buffers(model_path,
                                   programs,
-                                  separate_params=False,
                                   params_filename=None,
                                   append_suffix=True):
-    var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
+    var_info_filename = str(params_filename) + ".info"
+    var_info_path = os.path.join(model_path, var_info_filename)
     if os.path.exists(var_info_path):
         var_dict = _load_persistable_vars(model_path, var_info_path,
-                                          programs['forward'], separate_params,
-                                          params_filename)
+                                          programs['forward'], params_filename)
     else:
         var_dict = _load_persistable_vars_by_program(
             model_path, programs['forward'], params_filename)
@@ -700,18 +687,16 @@ class TranslatedLayer(layers.Layer):
             raise ValueError("There is no directory named '%s'" % model_path)
         model_filename = None
         params_filename = None
-        separate_params = False
         if configs is not None:
             model_filename = configs.model_filename
             params_filename = configs.params_filename
-            separate_params = configs.separate_params

         # 1. load program desc & construct _ProgramHolder
         programs = _construct_program_holders(model_path, model_filename)

         # 2. load layer parameters & buffers
-        persistable_vars = _construct_params_and_buffers(
-            model_path, programs, separate_params, params_filename)
+        persistable_vars = _construct_params_and_buffers(model_path, programs,
+                                                         params_filename)

         # 3. construct TranslatedLayer object
         translated_layer = TranslatedLayer(programs, persistable_vars)
......
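For reference, the new constants tie every saved artifact to the save prefix instead of fixed names like `__variables__`; a sketch of the resulting file layout for a hypothetical prefix:

    from paddle.fluid.dygraph.io import (INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX,
                                         INFER_PARAMS_INFO_SUFFIX)

    prefix = "./inference/resnet"  # hypothetical save prefix
    model_file = prefix + INFER_MODEL_SUFFIX       # ./inference/resnet.pdmodel
    params_file = prefix + INFER_PARAMS_SUFFIX     # ./inference/resnet.pdiparams
    info_file = prefix + INFER_PARAMS_INFO_SUFFIX  # ./inference/resnet.pdiparams.info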
This diff is collapsed.
@@ -14,7 +14,7 @@

 from __future__ import print_function

-from paddle.fluid.dygraph.jit import SaveLoadConfig
+from paddle.fluid.dygraph.jit import _SaveLoadConfig
 from paddle.fluid.dygraph.io import TranslatedLayer
@@ -31,7 +31,7 @@ class StaticModelRunner(object):
     """

     def __new__(cls, model_dir, model_filename=None, params_filename=None):
-        configs = SaveLoadConfig()
+        configs = _SaveLoadConfig()
         if model_filename is not None:
             configs.model_filename = model_filename
         if params_filename is not None:
......
@@ -28,11 +28,12 @@ class PredictorTools(object):
     Paddle-Inference predictor
     '''

-    def __init__(self, model_path, params_file, feeds_var):
+    def __init__(self, model_path, model_file, params_file, feeds_var):
         '''
         __init__
         '''
         self.model_path = model_path
+        self.model_file = model_file
         self.params_file = params_file
         self.feeds_var = feeds_var
@@ -43,7 +44,7 @@ class PredictorTools(object):
         '''
         if os.path.exists(os.path.join(self.model_path, self.params_file)):
             config = AnalysisConfig(
-                os.path.join(self.model_path, "__model__"),
+                os.path.join(self.model_path, self.model_file),
                 os.path.join(self.model_path, self.params_file))
         else:
             config = AnalysisConfig(os.path.join(self.model_path))
......
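Call sites therefore name the model file explicitly now; a usage sketch assuming a model saved under the prefix `./inference/mnist` (`input_data` is a placeholder):

    from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
    from predictor_utils import PredictorTools

    output = PredictorTools(
        "./inference",                  # model_path: directory holding the files
        "mnist" + INFER_MODEL_SUFFIX,   # model_file: mnist.pdmodel
        "mnist" + INFER_PARAMS_SUFFIX,  # params_file: mnist.pdiparams
        [input_data])                   # feeds_var
    out = output()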
@@ -12,13 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import time
 import unittest
 import numpy as np

+import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 from bert_dygraph_model import PretrainModelLayer
 from bert_utils import get_bert_config, get_feed_data_reader
@@ -31,7 +33,10 @@ place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() else fluid.CPUPlace(
 SEED = 2020
 STEP_NUM = 10
 PRINT_STEP = 2
-MODEL_SAVE_PATH = "./bert.inference.model"
+MODEL_SAVE_DIR = "./inference"
+MODEL_SAVE_PREFIX = "./inference/bert"
+MODEL_FILENAME = "bert" + INFER_MODEL_SUFFIX
+PARAMS_FILENAME = "bert" + INFER_PARAMS_SUFFIX
 DY_STATE_DICT_SAVE_PATH = "./bert.dygraph"
@@ -85,7 +90,7 @@ def train(bert_config, data_reader, to_static):
             step_idx += 1
             if step_idx == STEP_NUM:
                 if to_static:
-                    fluid.dygraph.jit.save(bert, MODEL_SAVE_PATH)
+                    fluid.dygraph.jit.save(bert, MODEL_SAVE_PREFIX)
                 else:
                     fluid.dygraph.save_dygraph(bert.state_dict(),
                                                DY_STATE_DICT_SAVE_PATH)
@@ -104,11 +109,15 @@ def train_static(bert_config, data_reader):

 def predict_static(data):
+    paddle.enable_static()
     exe = fluid.Executor(place)
     # load inference model
     [inference_program, feed_target_names,
      fetch_targets] = fluid.io.load_inference_model(
-         MODEL_SAVE_PATH, executor=exe, params_filename=VARIABLE_FILENAME)
+         MODEL_SAVE_DIR,
+         executor=exe,
+         model_filename=MODEL_FILENAME,
+         params_filename=PARAMS_FILENAME)
     pred_res = exe.run(inference_program,
                        feed=dict(zip(feed_target_names, data)),
                        fetch_list=fetch_targets)
@@ -143,7 +152,7 @@ def predict_dygraph(bert_config, data):

 def predict_dygraph_jit(data):
     with fluid.dygraph.guard(place):
-        bert = fluid.dygraph.jit.load(MODEL_SAVE_PATH)
+        bert = fluid.dygraph.jit.load(MODEL_SAVE_PREFIX)
         bert.eval()
         src_ids, pos_ids, sent_ids, input_mask, mask_label, mask_pos, labels = data
@@ -155,7 +164,8 @@ def predict_dygraph_jit(data):

 def predict_analysis_inference(data):
-    output = PredictorTools(MODEL_SAVE_PATH, VARIABLE_FILENAME, data)
+    output = PredictorTools(MODEL_SAVE_DIR, MODEL_FILENAME, PARAMS_FILENAME,
+                            data)
     out = output()
     return out
......
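The same loading pattern recurs throughout the test files below: the save directory and both file names are now passed explicitly instead of relying on the `__model__`/`__variables__` defaults. A condensed sketch of that pattern, with the surrounding test variables assumed:

    paddle.enable_static()
    exe = fluid.Executor(place)
    [inference_program, feed_target_names,
     fetch_targets] = fluid.io.load_inference_model(
         MODEL_SAVE_DIR,                   # e.g. "./inference"
         executor=exe,
         model_filename=MODEL_FILENAME,    # e.g. "bert.pdmodel"
         params_filename=PARAMS_FILENAME)  # e.g. "bert.pdiparams"
    pred_res = exe.run(inference_program,
                       feed=dict(zip(feed_target_names, data)),
                       fetch_list=fetch_targets)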
@@ -21,7 +21,7 @@ import paddle.fluid as fluid
 from paddle.fluid import ParamAttr
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.dygraph import ProgramTranslator
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 from predictor_utils import PredictorTools
@@ -422,7 +422,10 @@ class Args(object):
     prop_boundary_ratio = 0.5
     num_sample = 2
     num_sample_perbin = 2
-    infer_dir = './bmn_infer_model'
+    model_save_dir = "./inference"
+    model_save_prefix = "./inference/bmn"
+    model_filename = "bmn" + INFER_MODEL_SUFFIX
+    params_filename = "bmn" + INFER_PARAMS_SUFFIX
     dy_param_path = './bmn_dy_param'
@@ -620,7 +623,7 @@ def train_bmn(args, place, to_static):
             if batch_id == args.train_batch_num:
                 if to_static:
-                    fluid.dygraph.jit.save(bmn, args.infer_dir)
+                    fluid.dygraph.jit.save(bmn, args.model_save_prefix)
                 else:
                     fluid.dygraph.save_dygraph(bmn.state_dict(),
                                                args.dy_param_path)
@@ -735,13 +738,15 @@ class TestTrain(unittest.TestCase):
         return pred_res

     def predict_static(self, data):
+        paddle.enable_static()
         exe = fluid.Executor(self.place)
         # load inference model
         [inference_program, feed_target_names,
          fetch_targets] = fluid.io.load_inference_model(
-             self.args.infer_dir,
+             self.args.model_save_dir,
              executor=exe,
-             params_filename=VARIABLE_FILENAME)
+             model_filename=self.args.model_filename,
+             params_filename=self.args.params_filename)
         pred_res = exe.run(inference_program,
                            feed={feed_target_names[0]: data},
                            fetch_list=fetch_targets)
@@ -750,7 +755,7 @@ class TestTrain(unittest.TestCase):

     def predict_dygraph_jit(self, data):
         with fluid.dygraph.guard(self.place):
-            bmn = fluid.dygraph.jit.load(self.args.infer_dir)
+            bmn = fluid.dygraph.jit.load(self.args.model_save_prefix)
             bmn.eval()

             x = to_variable(data)
@@ -760,7 +765,9 @@ class TestTrain(unittest.TestCase):
         return pred_res

     def predict_analysis_inference(self, data):
-        output = PredictorTools(self.args.infer_dir, VARIABLE_FILENAME, [data])
+        output = PredictorTools(self.args.model_save_dir,
+                                self.args.model_filename,
+                                self.args.params_filename, [data])
         out = output()
         return out
......
@@ -26,7 +26,7 @@ import paddle.fluid as fluid
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.dygraph import Embedding, Linear, GRUUnit
 from paddle.fluid.dygraph import declarative, ProgramTranslator
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 from predictor_utils import PredictorTools
@@ -395,7 +395,10 @@ class Args(object):
     base_learning_rate = 0.01
     bigru_num = 2
     print_steps = 1
-    model_save_dir = "./lac_model"
+    model_save_dir = "./inference"
+    model_save_prefix = "./inference/lac"
+    model_filename = "lac" + INFER_MODEL_SUFFIX
+    params_filename = "lac" + INFER_PARAMS_SUFFIX
     dy_param_path = "./lac_dy_param"
@@ -498,13 +501,11 @@ def do_train(args, to_static):
             step += 1
         # save inference model
         if to_static:
-            configs = fluid.dygraph.jit.SaveLoadConfig()
-            configs.output_spec = [crf_decode]
             fluid.dygraph.jit.save(
                 layer=model,
-                model_path=args.model_save_dir,
+                path=args.model_save_prefix,
                 input_spec=[words, length],
-                configs=configs)
+                output_spec=[crf_decode])
         else:
             fluid.dygraph.save_dygraph(model.state_dict(), args.dy_param_path)
@@ -573,13 +574,15 @@ class TestLACModel(unittest.TestCase):
        LAC model contains h_0 created in `__init__` that is necessary for inferring.
        Load inference model to test it's ok for prediction.
        """
+        paddle.enable_static()
         exe = fluid.Executor(self.place)
         # load inference model
         [inference_program, feed_target_names,
          fetch_targets] = fluid.io.load_inference_model(
             self.args.model_save_dir,
             executor=exe,
-            params_filename=VARIABLE_FILENAME)
+            model_filename=self.args.model_filename,
+            params_filename=self.args.params_filename)
         words, targets, length = batch
         pred_res = exe.run(
@@ -592,7 +595,7 @@ class TestLACModel(unittest.TestCase):
     def predict_dygraph_jit(self, batch):
         words, targets, length = batch
         with fluid.dygraph.guard(self.place):
-            model = fluid.dygraph.jit.load(self.args.model_save_dir)
+            model = fluid.dygraph.jit.load(self.args.model_save_prefix)
             model.eval()

             pred_res = model(to_variable(words), to_variable(length))
@@ -602,8 +605,9 @@ class TestLACModel(unittest.TestCase):
     def predict_analysis_inference(self, batch):
         words, targets, length = batch
-        output = PredictorTools(self.args.model_save_dir, VARIABLE_FILENAME,
-                                [words, length])
+        output = PredictorTools(self.args.model_save_dir,
+                                self.args.model_filename,
+                                self.args.params_filename, [words, length])
         out = output()
         return out
......
@@ -25,7 +25,7 @@ from paddle.fluid.dygraph.base import switch_to_static_graph
 from paddle.fluid.dygraph import to_variable
 from paddle.fluid.dygraph.nn import Conv2D, Linear, Pool2D
 from paddle.fluid.optimizer import AdamOptimizer
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator

 from predictor_utils import PredictorTools
@@ -218,34 +218,39 @@ class TestMNISTWithToStatic(TestMNIST):
     def check_jit_save_load(self, model, inputs, input_spec, to_static, gt_out):
         if to_static:
             infer_model_path = "./test_mnist_inference_model_by_jit_save"
-            configs = fluid.dygraph.jit.SaveLoadConfig()
-            configs.output_spec = [gt_out]
+            model_save_dir = "./inference"
+            model_save_prefix = "./inference/mnist"
+            model_filename = "mnist" + INFER_MODEL_SUFFIX
+            params_filename = "mnist" + INFER_PARAMS_SUFFIX
             fluid.dygraph.jit.save(
                 layer=model,
-                model_path=infer_model_path,
+                path=model_save_prefix,
                 input_spec=input_spec,
-                configs=configs)
+                output_spec=[gt_out])
             # load in static mode
             static_infer_out = self.jit_load_and_run_inference_static(
-                infer_model_path, inputs)
+                model_save_dir, model_filename, params_filename, inputs)
             self.assertTrue(np.allclose(gt_out.numpy(), static_infer_out))
             # load in dygraph mode
             dygraph_infer_out = self.jit_load_and_run_inference_dygraph(
-                infer_model_path, inputs)
+                model_save_prefix, inputs)
             self.assertTrue(np.allclose(gt_out.numpy(), dygraph_infer_out))
             # load in Paddle-Inference
             predictor_infer_out = self.predictor_load_and_run_inference_analysis(
-                infer_model_path, inputs)
+                model_save_dir, model_filename, params_filename, inputs)
             self.assertTrue(np.allclose(gt_out.numpy(), predictor_infer_out))

     @switch_to_static_graph
-    def jit_load_and_run_inference_static(self, model_path, inputs):
+    def jit_load_and_run_inference_static(self, model_path, model_filename,
+                                          params_filename, inputs):
+        paddle.enable_static()
         exe = fluid.Executor(self.place)
         [inference_program, feed_target_names,
          fetch_targets] = fluid.io.load_inference_model(
              dirname=model_path,
              executor=exe,
-             params_filename=VARIABLE_FILENAME)
+             model_filename=model_filename,
+             params_filename=params_filename)
         assert len(inputs) == len(feed_target_names)
         results = exe.run(inference_program,
                           feed=dict(zip(feed_target_names, inputs)),
@@ -258,8 +263,10 @@ class TestMNISTWithToStatic(TestMNIST):
         pred = infer_net(inputs[0])
         return pred.numpy()

-    def predictor_load_and_run_inference_analysis(self, model_path, inputs):
-        output = PredictorTools(model_path, VARIABLE_FILENAME, inputs)
+    def predictor_load_and_run_inference_analysis(
+            self, model_path, model_filename, params_filename, inputs):
+        output = PredictorTools(model_path, model_filename, params_filename,
+                                inputs)
         out = output()
         return out
......
@@ -20,7 +20,7 @@ from paddle.fluid.initializer import MSRA
 from paddle.fluid.param_attr import ParamAttr
 from paddle.fluid.dygraph.nn import Conv2D, Pool2D, BatchNorm, Linear
 from paddle.fluid.dygraph import declarative, ProgramTranslator
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 import unittest
@@ -439,7 +439,10 @@ class Args(object):
     train_step = 10
     place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda(
     ) else fluid.CPUPlace()
-    model_save_path = model + ".inference.model"
+    model_save_dir = "./inference"
+    model_save_prefix = "./inference/" + model
+    model_filename = model + INFER_MODEL_SUFFIX
+    params_filename = model + INFER_PARAMS_SUFFIX
     dy_state_dict_save_path = model + ".dygraph"
@@ -504,7 +507,7 @@ def train_mobilenet(args, to_static):
                 t_last = time.time()
                 if batch_id > args.train_step:
                     if to_static:
-                        fluid.dygraph.jit.save(net, args.model_save_path)
+                        fluid.dygraph.jit.save(net, args.model_save_prefix)
                     else:
                         fluid.dygraph.save_dygraph(net.state_dict(),
                                                    args.dy_state_dict_save_path)
@@ -514,11 +517,15 @@ def train_mobilenet(args, to_static):

 def predict_static(args, data):
+    paddle.enable_static()
     exe = fluid.Executor(args.place)
     # load inference model
     [inference_program, feed_target_names,
      fetch_targets] = fluid.io.load_inference_model(
-         args.model_save_path, executor=exe, params_filename=VARIABLE_FILENAME)
+         args.model_save_dir,
+         executor=exe,
+         model_filename=args.model_filename,
+         params_filename=args.params_filename)

     pred_res = exe.run(inference_program,
                        feed={feed_target_names[0]: data},
@@ -545,7 +552,7 @@ def predict_dygraph(args, data):

 def predict_dygraph_jit(args, data):
     with fluid.dygraph.guard(args.place):
-        model = fluid.dygraph.jit.load(args.model_save_path)
+        model = fluid.dygraph.jit.load(args.model_save_prefix)
         model.eval()

         pred_res = model(data)
@@ -554,7 +561,8 @@ def predict_dygraph_jit(args, data):

 def predict_analysis_inference(args, data):
-    output = PredictorTools(args.model_save_path, VARIABLE_FILENAME, [data])
+    output = PredictorTools(args.model_save_dir, args.model_filename,
+                            args.params_filename, [data])
     out = output()
     return out
@@ -565,7 +573,9 @@ class TestMobileNet(unittest.TestCase):

     def train(self, model_name, to_static):
         self.args.model = model_name
-        self.args.model_save_path = model_name + ".inference.model"
+        self.args.model_save_prefix = "./inference/" + model_name
+        self.args.model_filename = model_name + INFER_MODEL_SUFFIX
+        self.args.params_filename = model_name + INFER_PARAMS_SUFFIX
         self.args.dy_state_dict_save_path = model_name + ".dygraph"
         out = train_mobilenet(self.args, to_static)
         return out
@@ -579,7 +589,9 @@ class TestMobileNet(unittest.TestCase):

     def assert_same_predict(self, model_name):
         self.args.model = model_name
-        self.args.model_save_path = model_name + ".inference.model"
+        self.args.model_save_prefix = "./inference/" + model_name
+        self.args.model_filename = model_name + INFER_MODEL_SUFFIX
+        self.args.params_filename = model_name + INFER_PARAMS_SUFFIX
         self.args.dy_state_dict_save_path = model_name + ".dygraph"
         local_random = np.random.RandomState(SEED)
         image = local_random.random_sample([1, 3, 224, 224]).astype('float32')
......
@@ -24,7 +24,7 @@ import paddle
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import declarative, ProgramTranslator
 from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 from predictor_utils import PredictorTools
@@ -38,7 +38,11 @@ batch_size = 2
 epoch_num = 1
 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
     else fluid.CPUPlace()
-MODEL_SAVE_PATH = "./resnet.inference.model"
+
+MODEL_SAVE_DIR = "./inference"
+MODEL_SAVE_PREFIX = "./inference/resnet"
+MODEL_FILENAME = "resnet" + INFER_MODEL_SUFFIX
+PARAMS_FILENAME = "resnet" + INFER_PARAMS_SUFFIX
 DY_STATE_DICT_SAVE_PATH = "./resnet.dygraph"

 program_translator = ProgramTranslator()
@@ -261,7 +265,7 @@ def train(to_static):
                     total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time))
                 if batch_id == 10:
                     if to_static:
-                        fluid.dygraph.jit.save(resnet, MODEL_SAVE_PATH)
+                        fluid.dygraph.jit.save(resnet, MODEL_SAVE_PREFIX)
                     else:
                         fluid.dygraph.save_dygraph(resnet.state_dict(),
                                                    DY_STATE_DICT_SAVE_PATH)
@@ -287,10 +291,14 @@ def predict_dygraph(data):

 def predict_static(data):
+    paddle.enable_static()
     exe = fluid.Executor(place)
     [inference_program, feed_target_names,
      fetch_targets] = fluid.io.load_inference_model(
-         MODEL_SAVE_PATH, executor=exe, params_filename=VARIABLE_FILENAME)
+         MODEL_SAVE_DIR,
+         executor=exe,
+         model_filename=MODEL_FILENAME,
+         params_filename=PARAMS_FILENAME)

     pred_res = exe.run(inference_program,
                        feed={feed_target_names[0]: data},
@@ -301,7 +309,7 @@ def predict_static(data):

 def predict_dygraph_jit(data):
     with fluid.dygraph.guard(place):
-        resnet = fluid.dygraph.jit.load(MODEL_SAVE_PATH)
+        resnet = fluid.dygraph.jit.load(MODEL_SAVE_PREFIX)
         resnet.eval()
         pred_res = resnet(data)
@@ -310,7 +318,8 @@ def predict_dygraph_jit(data):

 def predict_analysis_inference(data):
-    output = PredictorTools(MODEL_SAVE_PATH, VARIABLE_FILENAME, [data])
+    output = PredictorTools(MODEL_SAVE_DIR, MODEL_FILENAME, PARAMS_FILENAME,
+                            [data])
     out = output()
     return out
......
@@ -34,7 +34,11 @@ batch_size = 2
 epoch_num = 1
 place = paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
     else paddle.CPUPlace()
-MODEL_SAVE_PATH = "./resnet_v2.inference.model"
+
+MODEL_SAVE_DIR = "./inference"
+MODEL_SAVE_PREFIX = "./inference/resnet_v2"
+MODEL_FILENAME = "resnet_v2" + paddle.fluid.dygraph.io.INFER_MODEL_SUFFIX
+PARAMS_FILENAME = "resnet_v2" + paddle.fluid.dygraph.io.INFER_PARAMS_SUFFIX
 DY_STATE_DICT_SAVE_PATH = "./resnet_v2.dygraph"

 program_translator = paddle.jit.ProgramTranslator()
@@ -255,7 +259,7 @@ def train(to_static):
                     total_acc1.numpy() / total_sample, total_acc5.numpy() / total_sample, end_time-start_time))
                 if batch_id == 10:
                     if to_static:
-                        paddle.jit.save(resnet, MODEL_SAVE_PATH)
+                        paddle.jit.save(resnet, MODEL_SAVE_PREFIX)
                     else:
                         paddle.fluid.dygraph.save_dygraph(resnet.state_dict(),
                                                           DY_STATE_DICT_SAVE_PATH)
@@ -289,9 +293,10 @@ def predict_static(data):
     exe = paddle.static.Executor(place)
     [inference_program, feed_target_names,
      fetch_targets] = paddle.static.load_inference_model(
-         MODEL_SAVE_PATH,
+         MODEL_SAVE_DIR,
          executor=exe,
-         params_filename=paddle.fluid.dygraph.io.VARIABLE_FILENAME)
+         model_filename=MODEL_FILENAME,
+         params_filename=PARAMS_FILENAME)

     pred_res = exe.run(inference_program,
                        feed={feed_target_names[0]: data},
@@ -302,7 +307,7 @@ def predict_static(data):

 def predict_dygraph_jit(data):
     paddle.disable_static(place)
-    resnet = paddle.jit.load(MODEL_SAVE_PATH)
+    resnet = paddle.jit.load(MODEL_SAVE_PREFIX)
     resnet.eval()
     pred_res = resnet(data)
@@ -313,8 +318,8 @@ def predict_dygraph_jit(data):

 def predict_analysis_inference(data):
-    output = PredictorTools(MODEL_SAVE_PATH,
-                            paddle.fluid.dygraph.io.VARIABLE_FILENAME, [data])
+    output = PredictorTools(MODEL_SAVE_DIR, MODEL_FILENAME, PARAMS_FILENAME,
+                            [data])
     out = output()
     return out
......
@@ -16,14 +16,14 @@ from __future__ import print_function

 import os
 import unittest
 import numpy as np

-import paddle.fluid as fluid
+import paddle
+import paddle.fluid as fluid
 from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
 from paddle.fluid.dygraph.jit import declarative
 from paddle.fluid.dygraph.dygraph_to_static.partial_program import partial_program_from
-from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX

 SEED = 2020
@@ -66,14 +66,13 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
             adam.minimize(loss)
             layer.clear_gradients()
             # test for saving model in dygraph.guard
-            infer_model_dir = "./test_dy2stat_save_inference_model_in_guard"
-            configs = fluid.dygraph.jit.SaveLoadConfig()
-            configs.output_spec = [pred]
+            infer_model_prefix = "./test_dy2stat_inference_in_guard/model"
+            infer_model_dir = "./test_dy2stat_inference_in_guard"
             fluid.dygraph.jit.save(
                 layer=layer,
-                model_path=infer_model_dir,
+                path=infer_model_prefix,
                 input_spec=[x],
-                configs=configs)
+                output_spec=[pred])
         # Check the correctness of the inference
         dygraph_out, _ = layer(x)
         self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
@@ -91,30 +90,30 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):

         expected_persistable_vars = set([p.name for p in model.parameters()])

-        infer_model_dir = "./test_dy2stat_save_inference_model"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        if fetch is not None:
-            configs.output_spec = fetch
-            configs.separate_params = True
+        infer_model_prefix = "./test_dy2stat_inference/model"
+        infer_model_dir = "./test_dy2stat_inference"
+        model_filename = "model" + INFER_MODEL_SUFFIX
+        params_filename = "model" + INFER_PARAMS_SUFFIX
         fluid.dygraph.jit.save(
             layer=model,
-            model_path=infer_model_dir,
+            path=infer_model_prefix,
             input_spec=feed if feed else None,
-            configs=configs)
-        saved_var_names = set([
-            filename for filename in os.listdir(infer_model_dir)
-            if filename != '__model__' and filename != EXTRA_VAR_INFO_FILENAME
-        ])
-        self.assertEqual(saved_var_names, expected_persistable_vars)
+            output_spec=fetch if fetch else None)
         # Check the correctness of the inference
-        infer_out = self.load_and_run_inference(infer_model_dir, inputs)
+        infer_out = self.load_and_run_inference(infer_model_dir, model_filename,
+                                                params_filename, inputs)
         self.assertTrue(np.allclose(gt_out, infer_out))

-    def load_and_run_inference(self, model_path, inputs):
+    def load_and_run_inference(self, model_path, model_filename,
+                               params_filename, inputs):
+        paddle.enable_static()
         exe = fluid.Executor(place)
         [inference_program, feed_target_names,
          fetch_targets] = fluid.io.load_inference_model(
-             dirname=model_path, executor=exe)
+             dirname=model_path,
+             executor=exe,
+             model_filename=model_filename,
+             params_filename=params_filename)
         results = exe.run(inference_program,
                           feed=dict(zip(feed_target_names, inputs)),
                           fetch_list=fetch_targets)
......
@@ -24,7 +24,7 @@ from paddle.fluid.dygraph.base import to_variable
 from paddle.fluid.dygraph.nn import BatchNorm, Conv2D, Linear, Pool2D
 from paddle.fluid.dygraph import declarative
 from paddle.fluid.dygraph import ProgramTranslator
-from paddle.fluid.dygraph.io import VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX

 from predictor_utils import PredictorTools
@@ -35,7 +35,10 @@ BATCH_SIZE = 8
 EPOCH_NUM = 1
 PRINT_STEP = 2
 STEP_NUM = 10
-MODEL_SAVE_PATH = "./se_resnet.inference.model"
+MODEL_SAVE_DIR = "./inference"
+MODEL_SAVE_PREFIX = "./inference/se_resnet"
+MODEL_FILENAME = "se_resnet" + INFER_MODEL_SUFFIX
+PARAMS_FILENAME = "se_resnet" + INFER_PARAMS_SUFFIX
 DY_STATE_DICT_SAVE_PATH = "./se_resnet.dygraph"

 place = fluid.CUDAPlace(0) if fluid.is_compiled_with_cuda() \
@@ -383,10 +386,10 @@ def train(train_reader, to_static):
             step_idx += 1
             if step_idx == STEP_NUM:
                 if to_static:
-                    configs = fluid.dygraph.jit.SaveLoadConfig()
-                    configs.output_spec = [pred]
-                    fluid.dygraph.jit.save(se_resnext, MODEL_SAVE_PATH,
-                                           [img], configs)
+                    fluid.dygraph.jit.save(
+                        se_resnext,
+                        MODEL_SAVE_PREFIX, [img],
+                        output_spec=[pred])
                 else:
                     fluid.dygraph.save_dygraph(se_resnext.state_dict(),
                                                DY_STATE_DICT_SAVE_PATH)
@@ -414,10 +417,14 @@ def predict_dygraph(data):

 def predict_static(data):
+    paddle.enable_static()
     exe = fluid.Executor(place)
     [inference_program, feed_target_names,
      fetch_targets] = fluid.io.load_inference_model(
-         MODEL_SAVE_PATH, executor=exe, params_filename=VARIABLE_FILENAME)
+         MODEL_SAVE_DIR,
+         executor=exe,
+         model_filename=MODEL_FILENAME,
+         params_filename=PARAMS_FILENAME)

     pred_res = exe.run(inference_program,
                        feed={feed_target_names[0]: data},
@@ -428,7 +435,7 @@ def predict_static(data):

 def predict_dygraph_jit(data):
     with fluid.dygraph.guard(place):
-        se_resnext = fluid.dygraph.jit.load(MODEL_SAVE_PATH)
+        se_resnext = fluid.dygraph.jit.load(MODEL_SAVE_PREFIX)
         se_resnext.eval()
         pred_res = se_resnext(data)
@@ -437,7 +444,8 @@ def predict_dygraph_jit(data):

 def predict_analysis_inference(data):
-    output = PredictorTools(MODEL_SAVE_PATH, VARIABLE_FILENAME, [data])
+    output = PredictorTools(MODEL_SAVE_DIR, MODEL_FILENAME, PARAMS_FILENAME,
+                            [data])
     out = output()
     return out
......
@@ -32,6 +32,7 @@ STEP_NUM = 10

 def train_static(args, batch_generator):
+    paddle.enable_static()
     paddle.manual_seed(SEED)
     paddle.framework.random._manual_program_seed(SEED)
     train_prog = fluid.Program()
......
@@ -277,7 +277,8 @@ def load_dygraph(model_path, keep_name_table=False):
     To load python2 saved models in python3.
     """
     try:
-        para_dict, opti_dict = fluid.load_dygraph(model_path, keep_name_table)
+        para_dict, opti_dict = fluid.load_dygraph(
+            model_path, keep_name_table=keep_name_table)
         return para_dict, opti_dict
     except UnicodeDecodeError:
         warnings.warn(
@@ -287,7 +288,7 @@ def load_dygraph(model_path, keep_name_table=False):
         if six.PY3:
             load_bak = pickle.load
             pickle.load = partial(load_bak, encoding="latin1")
-            para_dict, opti_dict = fluid.load_dygraph(model_path,
-                                                      keep_name_table)
+            para_dict, opti_dict = fluid.load_dygraph(
+                model_path, keep_name_table=keep_name_table)
             pickle.load = load_bak
             return para_dict, opti_dict
@@ -43,15 +43,14 @@ class TestDirectory(unittest.TestCase):
             'paddle.distributed.prepare_context', 'paddle.DataParallel',
             'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
             'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
-            'paddle.jit.save', 'paddle.jit.load', 'paddle.SaveLoadConfig',
-            'paddle.NoamDecay', 'paddle.PiecewiseDecay',
-            'paddle.NaturalExpDecay', 'paddle.ExponentialDecay',
-            'paddle.InverseTimeDecay', 'paddle.PolynomialDecay',
-            'paddle.CosineDecay', 'paddle.static.Executor',
-            'paddle.static.global_scope', 'paddle.static.scope_guard',
-            'paddle.static.append_backward', 'paddle.static.gradients',
-            'paddle.static.BuildStrategy', 'paddle.static.CompiledProgram',
-            'paddle.static.ExecutionStrategy',
+            'paddle.jit.save', 'paddle.jit.load', 'paddle.NoamDecay',
+            'paddle.PiecewiseDecay', 'paddle.NaturalExpDecay',
+            'paddle.ExponentialDecay', 'paddle.InverseTimeDecay',
+            'paddle.PolynomialDecay', 'paddle.CosineDecay',
+            'paddle.static.Executor', 'paddle.static.global_scope',
+            'paddle.static.scope_guard', 'paddle.static.append_backward',
+            'paddle.static.gradients', 'paddle.static.BuildStrategy',
+            'paddle.static.CompiledProgram', 'paddle.static.ExecutionStrategy',
             'paddle.static.default_main_program',
             'paddle.static.default_startup_program', 'paddle.static.Program',
             'paddle.static.name_scope', 'paddle.static.program_guard',
@@ -104,9 +103,7 @@ class TestDirectory(unittest.TestCase):
             'paddle.imperative.TracedLayer', 'paddle.imperative.declarative',
             'paddle.imperative.ProgramTranslator',
             'paddle.imperative.TranslatedLayer', 'paddle.imperative.jit.save',
-            'paddle.imperative.jit.load',
-            'paddle.imperative.jit.SaveLoadConfig',
-            'paddle.imperative.NoamDecay'
+            'paddle.imperative.jit.load', 'paddle.imperative.NoamDecay'
             'paddle.imperative.PiecewiseDecay',
             'paddle.imperative.NaturalExpDecay',
            'paddle.imperative.ExponentialDecay',
......
@@ -917,11 +917,6 @@ class TestDygraphPtbRnn(unittest.TestCase):
         state_dict = emb.state_dict()
         fluid.save_dygraph(state_dict, os.path.join('saved_dy', 'emb_dy'))

-        para_state_dict, opti_state_dict = fluid.load_dygraph(
-            os.path.join('saved_dy', 'emb_dy'), True)
-        self.assertTrue(para_state_dict != None)
-        self.assertTrue(opti_state_dict == None)
         para_state_dict, opti_state_dict = fluid.load_dygraph(
             os.path.join('saved_dy', 'emb_dy'), keep_name_table=True)
         self.assertTrue(para_state_dict != None)
...
@@ -23,7 +23,7 @@ from paddle.static import InputSpec
 import paddle.fluid as fluid
 from paddle.fluid.dygraph import Linear
 from paddle.fluid.dygraph import declarative, ProgramTranslator
-from paddle.fluid.dygraph.io import EXTRA_VAR_INFO_FILENAME, VARIABLE_FILENAME
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX

 BATCH_SIZE = 32
 BATCH_NUM = 10
@@ -127,8 +127,8 @@ class MultiLoadingLinearNet(fluid.dygraph.Layer):
     def __init__(self, size, model_path):
         super(MultiLoadingLinearNet, self).__init__()
         self._linear = Linear(size, size)
-        self._load_linear1 = fluid.dygraph.jit.load(model_path)
-        self._load_linear2 = fluid.dygraph.jit.load(model_path)
+        self._load_linear1 = paddle.jit.load(model_path)
+        self._load_linear2 = paddle.jit.load(model_path)

     @declarative
     def forward(self, x):
@@ -218,23 +218,20 @@ def train_with_label(layer, input_size=784, label_size=1):
 class TestJitSaveLoad(unittest.TestCase):
     def setUp(self):
-        self.model_path = "model.test_jit_save_load"
+        self.model_path = "test_jit_save_load/model"
         # enable dygraph mode
         fluid.enable_dygraph()
         # config seed
         paddle.manual_seed(SEED)
         paddle.framework.random._manual_program_seed(SEED)

-    def train_and_save_model(self, model_path=None, configs=None):
+    def train_and_save_model(self, model_path=None):
         layer = LinearNet(784, 1)
         example_inputs, layer, _ = train(layer)
         final_model_path = model_path if model_path else self.model_path
         orig_input_types = [type(x) for x in example_inputs]
-        fluid.dygraph.jit.save(
-            layer=layer,
-            model_path=final_model_path,
-            input_spec=example_inputs,
-            configs=configs)
+        paddle.jit.save(
+            layer=layer, path=final_model_path, input_spec=example_inputs)
         new_input_types = [type(x) for x in example_inputs]
         self.assertEqual(orig_input_types, new_input_types)
         return layer
@@ -243,13 +240,10 @@ class TestJitSaveLoad(unittest.TestCase):
         # train and save model
         train_layer = self.train_and_save_model()
         # load model
-        program_translator = ProgramTranslator()
-        program_translator.enable(False)
-        loaded_layer = fluid.dygraph.jit.load(self.model_path)
+        loaded_layer = paddle.jit.load(self.model_path)
         self.load_and_inference(train_layer, loaded_layer)
         self.load_dygraph_state_dict(train_layer)
         self.load_and_finetune(train_layer, loaded_layer)
-        program_translator.enable(True)

     def load_and_inference(self, train_layer, infer_layer):
         train_layer.eval()
@@ -274,7 +268,7 @@ class TestJitSaveLoad(unittest.TestCase):
         # construct new model
         new_layer = LinearNet(784, 1)
         orig_state_dict = new_layer.state_dict()
-        load_state_dict, _ = fluid.dygraph.load_dygraph(self.model_path)
+        load_state_dict = paddle.load(self.model_path)
         for structured_name in orig_state_dict:
             self.assertTrue(structured_name in load_state_dict)
         new_layer.set_state_dict(load_state_dict)
@@ -286,20 +280,24 @@ class TestJitSaveLoad(unittest.TestCase):
             np.array_equal(train_layer(x).numpy(), new_layer(x).numpy()))

     def test_load_dygraph_no_path(self):
-        model_path = "model.test_jit_save_load.no_path"
-        new_layer = LinearNet(784, 1)
+        model_path = "test_jit_save_load.no_path/model_path"
         with self.assertRaises(ValueError):
             model_dict, _ = fluid.dygraph.load_dygraph(model_path)

     def test_jit_load_model_incomplete(self):
-        model_path = "model.test_jit_save_load.remove_variables"
-        self.train_and_save_model(model_path=model_path)
-        # remove `__variables__`
-        var_path = os.path.join(model_path, VARIABLE_FILENAME)
+        model_path = "test_jit_save_load.remove_variables/model"
+        self.train_and_save_model(model_path)
+        # remove `.pdiparams`
+        var_path = model_path + INFER_PARAMS_SUFFIX
         os.remove(var_path)
         with self.assertRaises(ValueError):
             paddle.jit.load(model_path)

+    def test_jit_load_no_path(self):
+        path = "test_jit_save_load.no_path/model_path"
+        with self.assertRaises(ValueError):
+            loaded_layer = paddle.jit.load(path)
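The tests above exercise the new prefix-based layout: `paddle.jit.save` now takes a path prefix instead of a directory and derives its file names from it. A minimal sketch of the resulting artifacts (the directory name is hypothetical; `LinearNet` is the helper layer from this test file):

import paddle

layer = LinearNet(784, 1)
paddle.jit.save(
    layer,
    path="example/model",  # a prefix, not a directory
    input_spec=[paddle.static.InputSpec(shape=[None, 784], dtype='float32')])
# Artifacts produced under the new format:
#   example/model.pdmodel         (inference program)
#   example/model.pdiparams       (persistable parameters)
#   example/model.pdiparams.info  (extra variable info)
loaded_layer = paddle.jit.load("example/model")  # load by the same prefix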
 class TestSaveLoadWithInputSpec(unittest.TestCase):
     def setUp(self):
@@ -313,8 +311,7 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
             net.forward, input_spec=[InputSpec(
                 [None, 8], name='x')])

-        model_path = "model.input_spec.output_spec"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
+        model_path = "input_spec.output_spec/model"
         # check inputs and outputs
         self.assertTrue(len(net.forward.inputs) == 1)
         input_x = net.forward.inputs[0]
@@ -322,11 +319,11 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
         self.assertTrue(input_x.name == 'x')

         # 1. prune loss
-        configs.output_spec = net.forward.outputs[:1]
-        fluid.dygraph.jit.save(net, model_path, configs=configs)
+        output_spec = net.forward.outputs[:1]
+        paddle.jit.save(net, model_path, output_spec=output_spec)

         # 2. load to infer
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = paddle.jit.load(model_path)
         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
         pred = infer_layer(x)
@@ -334,8 +331,7 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
     def test_multi_in_out(self):
         net = LinearNetMultiInput(8, 8)

-        model_path = "model.multi_inout.output_spec1"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
+        model_path = "multi_inout.output_spec1/model"
         # 1. check inputs and outputs
         self.assertTrue(len(net.forward.inputs) == 2)
         input_x = net.forward.inputs[0]
@@ -344,11 +340,11 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
         self.assertTrue(input_y.shape == (-1, 8))

         # 2. prune loss
-        configs.output_spec = net.forward.outputs[:2]
-        fluid.dygraph.jit.save(net, model_path, configs=configs)
+        output_spec = net.forward.outputs[:2]
+        paddle.jit.save(net, model_path, output_spec=output_spec)

         # 3. load to infer
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = paddle.jit.load(model_path)
         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
         y = fluid.dygraph.to_variable(
@@ -357,11 +353,11 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
         pred_x, pred_y = infer_layer(x, y)

         # 1. prune y and loss
-        model_path = "model.multi_inout.output_spec2"
-        configs.output_spec = net.forward.outputs[:1]
-        fluid.dygraph.jit.save(net, model_path, [input_x], configs)
+        model_path = "multi_inout.output_spec2/model"
+        output_spec = net.forward.outputs[:1]
+        paddle.jit.save(net, model_path, [input_x], output_spec=output_spec)
         # 2. load again
-        infer_layer2 = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer2 = paddle.jit.load(model_path)
         # 3. predict
         pred_xx = infer_layer2(x)
@@ -377,44 +373,6 @@ class TestJitSaveLoadConfig(unittest.TestCase):
         paddle.manual_seed(SEED)
         paddle.framework.random._manual_program_seed(SEED)

-    def basic_save_load(self, layer, model_path, configs):
-        # 1. train & save
-        example_inputs, train_layer, _ = train(layer)
-        fluid.dygraph.jit.save(
-            layer=train_layer,
-            model_path=model_path,
-            input_spec=example_inputs,
-            configs=configs)
-        # 2. load
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
-        train_layer.eval()
-        # 3. inference & compare
-        x = fluid.dygraph.to_variable(
-            np.random.random((1, 784)).astype('float32'))
-        self.assertTrue(
-            np.array_equal(train_layer(x).numpy(), infer_layer(x).numpy()))
-
-    def test_model_filename(self):
-        layer = LinearNet(784, 1)
-        model_path = "model.save_load_config.output_spec"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.model_filename = "__simplenet__"
-        self.basic_save_load(layer, model_path, configs)
-
-    def test_params_filename(self):
-        layer = LinearNet(784, 1)
-        model_path = "model.save_load_config.params_filename"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.params_filename = "__params__"
-        self.basic_save_load(layer, model_path, configs)
-
-    def test_separate_params(self):
-        layer = LinearNet(784, 1)
-        model_path = "model.save_load_config.separate_params"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.separate_params = True
-        self.basic_save_load(layer, model_path, configs)
-
     def test_output_spec(self):
         train_layer = LinearNetReturnLoss(8, 8)
         adam = fluid.optimizer.AdamOptimizer(
@@ -427,27 +385,47 @@ class TestJitSaveLoadConfig(unittest.TestCase):
             adam.minimize(loss)
             train_layer.clear_gradients()

-        model_path = "model.save_load_config.output_spec"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.output_spec = [out]
-        fluid.dygraph.jit.save(
-            layer=train_layer,
-            model_path=model_path,
-            input_spec=[x],
-            configs=configs)
+        model_path = "save_load_config.output_spec"
+        output_spec = [out]
+        paddle.jit.save(
+            layer=train_layer,
+            path=model_path,
+            input_spec=[x],
+            output_spec=output_spec)

         train_layer.eval()
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = paddle.jit.load(model_path)
         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
         self.assertTrue(
             np.array_equal(train_layer(x)[0].numpy(), infer_layer(x).numpy()))

+    def test_save_no_support_config_error(self):
+        layer = LinearNet(784, 1)
+        path = "no_support_config_test"
+        with self.assertRaises(ValueError):
+            paddle.jit.save(layer=layer, path=path, model_filename="")
+
+    def test_load_empty_model_filename_error(self):
+        path = "error_model_filename_test"
+        with self.assertRaises(ValueError):
+            paddle.jit.load(path, model_filename="")
+
+    def test_load_empty_params_filename_error(self):
+        path = "error_params_filename_test"
+        with self.assertRaises(ValueError):
+            paddle.jit.load(path, params_filename="")
+
+    def test_load_with_no_support_config(self):
+        path = "no_support_config_test"
+        with self.assertRaises(ValueError):
+            paddle.jit.load(path, separate_params=True)
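These new tests pin down the core of the PR: the `SaveLoadConfig` object is gone and its fields became plain keyword arguments that are validated eagerly. Roughly, the migration looks like this (a sketch based on the removed and added test code above; `layer`, `path`, `x`, and `out` are stand-ins):

# Before (removed in this PR):
#     configs = paddle.SaveLoadConfig()
#     configs.output_spec = [out]
#     paddle.jit.save(layer, model_path, input_spec=[x], configs=configs)

# After: the same option is a keyword argument ...
paddle.jit.save(layer, path, input_spec=[x], output_spec=[out])
# ... and unsupported or empty options fail fast:
paddle.jit.load(path, separate_params=True)  # raises ValueError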
 class TestJitMultipleLoading(unittest.TestCase):
     def setUp(self):
         self.linear_size = 4
-        self.model_path = "model.jit_multi_load"
+        self.model_path = "jit_multi_load/model"
         # enable dygraph mode
         fluid.enable_dygraph()
         # config seed
@@ -459,8 +437,8 @@ class TestJitMultipleLoading(unittest.TestCase):
     def train_and_save_orig_model(self):
         layer = LinearNet(self.linear_size, self.linear_size)
         example_inputs, layer, _ = train(layer, self.linear_size, 1)
-        fluid.dygraph.jit.save(
-            layer=layer, model_path=self.model_path, input_spec=example_inputs)
+        paddle.jit.save(
+            layer=layer, path=self.model_path, input_spec=example_inputs)

     def test_load_model_retransform_inference(self):
         multi_loaded_layer = MultiLoadingLinearNet(self.linear_size,
@@ -475,7 +453,7 @@ class TestJitMultipleLoading(unittest.TestCase):
 class TestJitPruneModelAndLoad(unittest.TestCase):
     def setUp(self):
         self.linear_size = 4
-        self.model_path = "model.jit_prune_model_and_load"
+        self.model_path = "jit_prune_model_and_load/model"
         # enable dygraph mode
         fluid.enable_dygraph()
         # config seed
@@ -494,13 +472,12 @@ class TestJitPruneModelAndLoad(unittest.TestCase):
             adam.minimize(loss)
             train_layer.clear_gradients()

-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.output_spec = [hidden]
-        fluid.dygraph.jit.save(
-            layer=train_layer,
-            model_path=self.model_path,
-            input_spec=[x],
-            configs=configs)
+        output_spec = [hidden]
+        paddle.jit.save(
+            layer=train_layer,
+            path=self.model_path,
+            input_spec=[x],
+            output_spec=output_spec)

         return train_layer
@@ -508,7 +485,7 @@ class TestJitPruneModelAndLoad(unittest.TestCase):
         train_layer = self.train_and_save()
         train_layer.eval()

-        infer_layer = fluid.dygraph.jit.load(self.model_path)
+        infer_layer = paddle.jit.load(self.model_path)

         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
@@ -519,7 +496,7 @@ class TestJitPruneModelAndLoad(unittest.TestCase):
         self.train_and_save()

         # chage extra var info
-        var_info_path = os.path.join(self.model_path, EXTRA_VAR_INFO_FILENAME)
+        var_info_path = self.model_path + INFER_PARAMS_INFO_SUFFIX
         with open(var_info_path, 'rb') as f:
             extra_var_info = pickle.load(f)
         extra_var_info.clear()
@@ -527,7 +504,7 @@ class TestJitPruneModelAndLoad(unittest.TestCase):
             pickle.dump(extra_var_info, f, protocol=2)

         with self.assertRaises(RuntimeError):
-            fluid.dygraph.jit.load(self.model_path)
+            paddle.jit.load(self.model_path)
 class TestJitSaveMultiCases(unittest.TestCase):
@@ -561,7 +538,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train(layer)

-        model_path = "test_no_prune_to_static_after_train"
+        model_path = "test_no_prune_to_static_after_train/model"
         paddle.jit.save(layer, model_path)

         self.verify_inference_correctness(layer, model_path)
@@ -569,7 +546,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
     def test_no_prune_to_static_no_train(self):
         layer = LinearNetWithInputSpec(784, 1)

-        model_path = "test_no_prune_to_static_no_train"
+        model_path = "test_no_prune_to_static_no_train/model"
         paddle.jit.save(layer, model_path)

         self.verify_inference_correctness(layer, model_path)
@@ -579,7 +556,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train(layer)

-        model_path = "test_no_prune_no_to_static_after_train"
+        model_path = "test_no_prune_no_to_static_after_train/model"
         paddle.jit.save(
             layer,
             model_path,
@@ -593,16 +570,15 @@ class TestJitSaveMultiCases(unittest.TestCase):
         example_inputs, _, _ = train(layer)

-        model_path = "test_no_prune_no_to_static_after_train_with_examples"
-        fluid.dygraph.jit.save(
-            layer=layer, model_path=model_path, input_spec=example_inputs)
+        model_path = "test_no_prune_no_to_static_after_train_with_examples/model"
+        paddle.jit.save(layer=layer, path=model_path, input_spec=example_inputs)

         self.verify_inference_correctness(layer, model_path)

     def test_no_prune_no_to_static_no_train(self):
         layer = LinearNetNotDeclarative(784, 1)

-        model_path = "test_no_prune_no_to_static_no_train"
+        model_path = "test_no_prune_no_to_static_no_train/model"
         paddle.jit.save(
             layer,
             model_path,
@@ -616,9 +592,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
         out = train_with_label(layer)

-        model_path = "test_prune_to_static_after_train"
-        configs = paddle.SaveLoadConfig()
-        configs.output_spec = [out]
+        model_path = "test_prune_to_static_after_train/model"
         paddle.jit.save(
             layer,
             model_path,
@@ -626,18 +600,17 @@ class TestJitSaveMultiCases(unittest.TestCase):
                 InputSpec(
                     shape=[None, 784], dtype='float32', name="image")
             ],
-            configs=configs)
+            output_spec=[out])

         self.verify_inference_correctness(layer, model_path, True)

     def test_prune_to_static_no_train(self):
         layer = LinerNetWithLabel(784, 1)

-        model_path = "test_prune_to_static_no_train"
-        configs = paddle.SaveLoadConfig()
+        model_path = "test_prune_to_static_no_train/model"
         # TODO: no train, cannot get output_spec var here
         # now only can use index
-        configs.output_spec = layer.forward.outputs[:1]
+        output_spec = layer.forward.outputs[:1]
         paddle.jit.save(
             layer,
             model_path,
@@ -645,7 +618,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
                 InputSpec(
                     shape=[None, 784], dtype='float32', name="image")
             ],
-            configs=configs)
+            output_spec=output_spec)

         self.verify_inference_correctness(layer, model_path, True)
@@ -654,7 +627,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train(layer)

-        model_path = "test_no_prune_input_spec_name_warning"
+        model_path = "test_no_prune_input_spec_name_warning/model"
         paddle.jit.save(
             layer,
             model_path,
@@ -675,18 +648,16 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train(layer)

-        model_path = "test_not_prune_output_spec_name_warning"
-        configs = paddle.SaveLoadConfig()
+        model_path = "test_not_prune_output_spec_name_warning/model"
         out = paddle.to_tensor(np.random.random((1, 1)).astype('float'))
-        configs.output_spec = [out]
-        paddle.jit.save(layer, model_path, configs=configs)
+        paddle.jit.save(layer, model_path, output_spec=[out])

         self.verify_inference_correctness(layer, model_path)

     def test_prune_input_spec_name_error(self):
         layer = LinerNetWithLabel(784, 1)

-        model_path = "test_prune_input_spec_name_error"
+        model_path = "test_prune_input_spec_name_error/model"
         with self.assertRaises(ValueError):
             paddle.jit.save(
                 layer,
@@ -707,10 +678,8 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train_with_label(layer)

-        model_path = "test_prune_to_static_after_train"
-        configs = paddle.SaveLoadConfig()
+        model_path = "test_prune_to_static_after_train/model"
         out = paddle.to_tensor(np.random.random((1, 1)).astype('float'))
-        configs.output_spec = [out]
         with self.assertRaises(ValueError):
             paddle.jit.save(
                 layer,
@@ -719,12 +688,12 @@ class TestJitSaveMultiCases(unittest.TestCase):
                     InputSpec(
                         shape=[None, 784], dtype='float32', name="image")
                 ],
-                configs=configs)
+                output_spec=[out])

 class TestJitSaveLoadEmptyLayer(unittest.TestCase):
     def setUp(self):
-        self.model_path = "model.jit_save_load_empty_layer"
+        self.model_path = "jit_save_load_empty_layer/model"
         # enable dygraph mode
         paddle.disable_static()
@@ -740,7 +709,7 @@ class TestJitSaveLoadEmptyLayer(unittest.TestCase):
 class TestJitSaveLoadNoParamLayer(unittest.TestCase):
     def setUp(self):
-        self.model_path = "model.jit_save_load_no_param_layer"
+        self.model_path = "jit_save_load_no_param_layer/model"
         # enable dygraph mode
         paddle.disable_static()
...
@@ -63,6 +63,8 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.epoch_num = 1
         self.batch_size = 128
         self.batch_num = 10
+        # enable static mode
+        paddle.enable_static()

     def train_and_save_model(self, only_params=False):
         with new_program_scope():
@@ -136,13 +138,12 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.params_filename = None
         orig_param_dict = self.train_and_save_model()

-        config = paddle.SaveLoadConfig()
-        config.separate_params = True
-        config.model_filename = self.model_filename
-        load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config)
+        load_param_dict, _ = fluid.load_dygraph(
+            self.save_dirname, model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, load_param_dict)

-        new_load_param_dict = paddle.load(self.save_dirname, config)
+        new_load_param_dict = paddle.load(
+            self.save_dirname, model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     def test_load_with_param_filename(self):
@@ -151,12 +152,12 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.params_filename = "static_mnist.params"
         orig_param_dict = self.train_and_save_model()

-        config = paddle.SaveLoadConfig()
-        config.params_filename = self.params_filename
-        load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config)
+        load_param_dict, _ = fluid.load_dygraph(
+            self.save_dirname, params_filename=self.params_filename)
         self.check_load_state_dict(orig_param_dict, load_param_dict)

-        new_load_param_dict = paddle.load(self.save_dirname, config)
+        new_load_param_dict = paddle.load(
+            self.save_dirname, params_filename=self.params_filename)
         self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     def test_load_with_model_and_param_filename(self):
@@ -165,13 +166,16 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.params_filename = "static_mnist.params"
         orig_param_dict = self.train_and_save_model()

-        config = paddle.SaveLoadConfig()
-        config.params_filename = self.params_filename
-        config.model_filename = self.model_filename
-        load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config)
+        load_param_dict, _ = fluid.load_dygraph(
+            self.save_dirname,
+            params_filename=self.params_filename,
+            model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, load_param_dict)

-        new_load_param_dict = paddle.load(self.save_dirname, config)
+        new_load_param_dict = paddle.load(
+            self.save_dirname,
+            params_filename=self.params_filename,
+            model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     def test_load_state_dict_from_save_params(self):
...
@@ -20,8 +20,8 @@ __all__ = [
 ]

 __all__ += [
-    'grad', 'LayerList', 'load', 'save', 'SaveLoadConfig', 'to_variable',
-    'no_grad', 'DataParallel'
+    'grad', 'LayerList', 'load', 'save', 'to_variable', 'no_grad',
+    'DataParallel'
 ]

 __all__ += [
@@ -50,7 +50,6 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS
 from ..fluid.dygraph.base import grad #DEFINE_ALIAS
 from .io import save
 from .io import load
-from ..fluid.dygraph.jit import SaveLoadConfig #DEFINE_ALIAS
 from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS
 from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS
...
@@ -26,7 +26,9 @@ import paddle
 from paddle import fluid
 from paddle.fluid import core
 from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer
-from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME
+from paddle.fluid.dygraph.jit import _SaveLoadConfig
+from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
+from paddle.fluid.dygraph.io import INFER_MODEL_SUFFIX, INFER_PARAMS_SUFFIX, INFER_PARAMS_INFO_SUFFIX

 __all__ = [
     'save',
@@ -55,19 +57,16 @@ def _load_state_dict_from_save_inference_model(model_path, config):
     # 2. load layer parameters & buffers
     with fluid.dygraph.guard():
-        persistable_var_dict = _construct_params_and_buffers(
-            model_path,
-            programs,
-            config.separate_params,
-            config.params_filename,
-            append_suffix=False)
+        persistable_var_dict = _construct_params_and_buffers(
+            model_path, programs, config.params_filename, append_suffix=False)

         # 3. construct state_dict
         load_param_dict = dict()
         for var_name in persistable_var_dict:
             load_param_dict[var_name] = persistable_var_dict[var_name].numpy()

-        # if __variables.info__ exists, we can recover structured_name
-        var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
+        # if *.info exists, we can recover structured_name
+        var_info_filename = str(config.params_filename) + ".info"
+        var_info_path = os.path.join(model_path, var_info_filename)
         if os.path.exists(var_info_path):
             with open(var_info_path, 'rb') as f:
                 extra_var_info = pickle.load(f)
@@ -116,12 +115,99 @@ def _load_state_dict_from_save_params(model_path):
     return load_param_dict
+# NOTE(chenweihang): [ Handling of use cases of API paddle.load ]
+# `paddle.load` may be used to load saved results of:
+#  1. Expected cases:
+#    - need [full filename] when loading
+#      - paddle.save
+#      - paddle.static.save
+#      - paddle.fluid.save_dygraph
+#    - need [prefix] when loading [compatible for paddle 2.x]
+#      - paddle.jit.save
+#      - paddle.static.save_inference_model
+#    - need [directory] when loading [compatible for paddle 1.x]
+#      - paddle.fluid.io.save_inference_model
+#      - paddle.fluid.io.save_params/save_persistable
+#  2. Error cases:
+#    - no error case
+def _build_load_path_and_config(path, config):
+    # NOTE(chenweihang): If both [prefix save format] and [directory save format] exist,
+    # raise error, avoid confusing behavior
+    prefix_format_path = path + INFER_MODEL_SUFFIX
+    prefix_format_exist = os.path.exists(prefix_format_path)
+    directory_format_exist = os.path.isdir(path)
+    if prefix_format_exist and directory_format_exist:
+        raise ValueError(
+            "The %s.pdmodel and %s directory exist at the same time, "
+            "don't know which one to load, please make sure that the specified target "
+            "of ``path`` is unique." % (path, path))
+    elif not prefix_format_exist and not directory_format_exist:
+        error_msg = "The ``path`` (%s) to load model not exists."
+        # if current path is a prefix, and the path.pdparams or path.pdopt
+        # is exist, users may want use `paddle.load` load the result of
+        # `fluid.save_dygraph`, we raise error here for users
+        params_file_path = path + ".pdparams"
+        opti_file_path = path + ".pdopt"
+        if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
+            error_msg += " If you want to load the results saved by `fluid.save_dygraph`, " \
+                "please specify the full file name, not just the file name prefix. For " \
+                "example, it should be written as `paddle.load('model.pdparams')` instead of " \
+                "`paddle.load('model')`."
+        raise ValueError(error_msg % path)
+    else:
+        if prefix_format_exist:
+            file_prefix = os.path.basename(path)
+            model_path = os.path.dirname(path)
+            if config.model_filename is not None:
+                warnings.warn(
+                    "When loading the result saved with the "
+                    "specified file prefix, the ``model_filename`` config does "
+                    "not take effect.")
+            config.model_filename = file_prefix + INFER_MODEL_SUFFIX
+            if config.params_filename is not None:
+                warnings.warn(
+                    "When loading the result saved with the "
+                    "specified file prefix, the ``params_filename`` config does "
+                    "not take effect.")
+            config.params_filename = file_prefix + INFER_PARAMS_SUFFIX
+        else:
+            # Compatible with the old save_inference_model format
+            model_path = path
+    return model_path, config
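In practice this resolution gives `paddle.load` three accepted path styles. A sketch of how they play out (all paths hypothetical, assumed to exist on disk):

import paddle

paddle.load("emb.pdparams")     # full filename: result of paddle.save / fluid.save_dygraph
paddle.load("inference/mnist")  # prefix: result of paddle.jit.save, resolved via mnist.pdmodel
paddle.load("old_model_dir")    # directory: paddle 1.x save_inference_model / save_params
# If both "inference/mnist.pdmodel" and a directory "inference/mnist" exist,
# _build_load_path_and_config raises ValueError rather than guessing.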
+def _parse_load_config(configs):
+    supported_configs = ['model_filename', 'params_filename', 'keep_name_table']
+
+    # input check
+    for key in configs:
+        if key not in supported_configs:
+            raise ValueError(
+                "The additional config (%s) of `paddle.load` is not supported."
+                % key)
+
+    # construct inner config
+    inner_config = _SaveLoadConfig()
+    inner_config.model_filename = configs.get('model_filename', None)
+    inner_config.params_filename = configs.get('params_filename', None)
+    inner_config.keep_name_table = configs.get('keep_name_table', None)
+
+    return inner_config
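So the public API takes plain kwargs and only builds the private `_SaveLoadConfig` internally. For example, loading a paddle 1.x `save_inference_model` directory with custom file names now looks like this (directory and file names hypothetical, matching the test case above):

state_dict = paddle.load(
    "static_mnist_dir",
    model_filename="static_mnist.model",
    params_filename="static_mnist.params")
# Any other keyword, e.g. separate_params=True, raises ValueError.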
 def save(obj, path):
     '''
     Save an object to the specified path.

     .. note::
         Now only supports save ``state_dict`` of Layer or Optimizer.

+    .. note::
+        ``paddle.save`` will not add a suffix to the saved results,
+        but we recommend that you use the following paddle standard suffixes:
+            1. for ``Layer.state_dict`` -> ``.pdparams``
+            2. for ``Optimizer.state_dict`` -> ``.pdopt``

     Args:
         obj(Object) : The object to be saved.
@@ -178,7 +264,7 @@ def save(obj, path):
         pickle.dump(saved_obj, f, protocol=2)
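A minimal sketch of the recommended suffixes in use, consistent with the docstring example quoted further below:

import paddle

emb = paddle.nn.Embedding(10, 10)
layer_state_dict = emb.state_dict()
paddle.save(layer_state_dict, "emb.pdparams")  # Layer state -> .pdparams

adam = paddle.optimizer.Adam(
    learning_rate=0.001, parameters=emb.parameters())
opt_state_dict = adam.state_dict()
paddle.save(opt_state_dict, "adam.pdopt")      # Optimizer state -> .pdopt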
-def load(path, config=None):
+def load(path, **configs):
     '''
     Load an object can be used in paddle from specified path.
@@ -186,21 +272,39 @@ def load(path, config=None):
         Now only supports load ``state_dict`` of Layer or Optimizer.

     .. note::
-        ``paddle.load`` supports loading ``state_dict`` from the result of several
-        paddle1.x save APIs in static mode, but due to some historical reasons,
-        if you load ``state_dict`` from the saved result of
-        ``paddle.static.save_inference_model/paddle.fluid.io.save_params/paddle.fluid.io.save_persistables`` ,
+        ``paddle.load`` supports loading ``state_dict`` of Layer or Optimizer from
+        the result of other save APIs except ``paddle.load`` , but the argument
+        ``path`` format is different:
+        1. loading from ``paddle.static.save`` or ``paddle.Model().save(training=True)`` ,
+        ``path`` needs to be a complete file name, such as ``model.pdparams`` or
+        ``model.pdopt`` ;
+        2. loading from ``paddle.jit.save`` or ``paddle.static.save_inference_model``
+        or ``paddle.Model().save(training=False)`` , ``path`` need to be a file prefix,
+        such as ``model/mnist``, and ``paddle.load`` will get information from
+        ``mnist.pdmodel`` and ``mnist.pdiparams`` ;
+        3. loading from paddle 1.x APIs ``paddle.fluid.io.save_inference_model`` or
+        ``paddle.fluid.io.save_params/save_persistables`` , ``path`` need to be a
+        directory, such as ``model`` and model is a directory.
+
+    .. note::
+        If you load ``state_dict`` from the saved result of
+        ``paddle.static.save`` or ``paddle.static.save_inference_model`` ,
         the structured variable name will cannot be restored. You need to set the argument
         ``use_structured_name=False`` when using ``Layer.set_state_dict`` later.

     Args:
         path(str) : The path to load the target object. Generally, the path is the target
-            file path, when compatible with loading the saved results of
-            ``paddle.jit.save/paddle.static.save_inference_model`` , the path is a directory.
-        config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
-            object that specifies additional configuration options, these options
-            are for compatibility with ``paddle.jit.save/paddle.static.save_inference_model``
-            formats. Default None.
+            file path. When compatible with loading the saved results other APIs, the path
+            can be a file prefix or directory.
+        **configs (dict, optional): other load configuration options for compatibility. We do not
+            recommend using these configurations, they may be removed in the future. If not necessary,
+            DO NOT use them. Default None.
+            The following options are currently supported:
+            (1) model_filename (string): The inference model file name of the paddle 1.x
+            ``save_inference_model`` save format. Default file name is :code:`__model__` .
+            (2) params_filename (string): The persistable variables file name of the paddle 1.x
+            ``save_inference_model`` save format. No default file name, save variables separately
+            by default.

     Returns:
         Object(Object): a target object can be used in paddle
@@ -227,26 +331,9 @@ def load(path, config=None):
             load_layer_state_dict = paddle.load("emb.pdparams")
             load_opt_state_dict = paddle.load("adam.pdopt")
     '''
-    # 1. input check
-    if not os.path.exists(path):
-        error_msg = "The path `%s` does not exist."
-        # if current path is a prefix, and the path.pdparams or path.pdopt
-        # is exist, users may want use `paddle.load` load the result of
-        # `fluid.save_dygraph`, we raise error here for users
-        params_file_path = path + ".pdparams"
-        opti_file_path = path + ".pdopt"
-        if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
-            error_msg += " If you want to load the results saved by `fluid.save_dygraph`, " \
-                "please specify the full file name, not just the file name prefix. For " \
-                "example, it should be written as `paddle.load('model.pdparams')` instead of " \
-                "`paddle.load('model')`."
-        raise ValueError(error_msg % path)
-
-    if config is None:
-        config = paddle.SaveLoadConfig()
-
-    # 2. load target
     load_result = None
+    config = _parse_load_config(configs)
+
     if os.path.isfile(path):
         # we think path is file means this file is created by paddle.save
         with open(path, 'rb') as f:
@@ -255,16 +342,15 @@ def load(path, config=None):
         if not config.keep_name_table and "StructuredToParameterName@@" in load_result:
             del load_result["StructuredToParameterName@@"]
-    elif os.path.isdir(path):
-        # we think path is directory means compatible with loading
-        # store results of static mode related save APIs
+    else:
+        # file prefix and directory are compatible cases
+        model_path, config = _build_load_path_and_config(path, config)

         # check whether model file exists
         if config.model_filename is None:
             model_filename = '__model__'
         else:
             model_filename = config.model_filename
-        model_file_path = os.path.join(path, model_filename)
+        model_file_path = os.path.join(model_path, model_filename)

         if os.path.exists(model_file_path):
             # Load state dict by `jit.save/io.save_inference_model` save format
@@ -274,7 +360,7 @@ def load(path, config=None):
             # `save_inference_model` not save structured name, we need to remind
             # the user to configure the `use_structured_name` argument when `set_state_dict`
             # NOTE(chenweihang): `jit.save` doesn't save optimizer state
-            load_result = _load_state_dict_from_save_inference_model(path,
+            load_result = _load_state_dict_from_save_inference_model(model_path,
                                                                      config)
         else:
             # load state dict by `io.save_params/persistables` save format
@@ -283,9 +369,6 @@ def load(path, config=None):
             # mapping info will lost, so users need to give variable list, but users build
             # variable list in dygraph mode is difficult, we recommend users to use
             # paddle.static.load_program_state in this case
-            load_result = _load_state_dict_from_save_params(path)
-    else:
-        raise ValueError(
-            "Unsupported path format, now only supports file or directory.")
+            load_result = _load_state_dict_from_save_params(model_path)

     return load_result
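Tying the compatibility path together, a short end-to-end sketch per the docstring note above (directory name hypothetical, assumed to hold a paddle 1.x `save_inference_model` result):

import paddle

# A linear layer standing in for the network whose parameters were saved.
layer = paddle.nn.Linear(784, 1)
state_dict = paddle.load("old_inference_dir")
# That format stores no structured names, so match by parameter name:
layer.set_state_dict(state_dict, use_structured_name=False)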