未验证 提交 209273e6 编写于 作者: C Chen Weihang 提交者: GitHub

Support loading a state dict from an `inference model` format save result (#26718)

* support load infer model format state dict

* add unittests

* remove keep name table

* resolve circular import

* fix compatible problem

* recover unittest

* polish doc and comment
上级 bcdbac17
...@@ -230,6 +230,7 @@ from .framework import grad #DEFINE_ALIAS ...@@ -230,6 +230,7 @@ from .framework import grad #DEFINE_ALIAS
from .framework import no_grad #DEFINE_ALIAS from .framework import no_grad #DEFINE_ALIAS
from .framework import save #DEFINE_ALIAS from .framework import save #DEFINE_ALIAS
from .framework import load #DEFINE_ALIAS from .framework import load #DEFINE_ALIAS
from .framework import SaveLoadConfig #DEFINE_ALIAS
from .framework import DataParallel #DEFINE_ALIAS from .framework import DataParallel #DEFINE_ALIAS
from .framework import NoamDecay #DEFINE_ALIAS from .framework import NoamDecay #DEFINE_ALIAS
......
...@@ -16,13 +16,16 @@ from __future__ import print_function ...@@ -16,13 +16,16 @@ from __future__ import print_function
import os import os
import collections import collections
import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle import pickle
import six import six
from . import learning_rate_scheduler from . import learning_rate_scheduler
import warnings import warnings
from .. import core from .. import core
from paddle.fluid.dygraph.io import VARIABLE_FILENAME, EXTRA_VAR_INFO_FILENAME, _load_persistable_vars from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers
__all__ = [ __all__ = [
'save_dygraph', 'save_dygraph',
...@@ -30,6 +33,37 @@ __all__ = [ ...@@ -30,6 +33,37 @@ __all__ = [
] ]
# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
# ensure compatibility when user still use keep_name_table argument
def deprecate_keep_name_table(func):
    """Decorator adapting the deprecated `keep_name_table` argument of
    ``load_dygraph`` to the new ``configs`` argument.

    Both the positional form ``load_dygraph(path, True)`` and the keyword
    form ``load_dygraph(path, keep_name_table=True)`` are recognized; in
    either case a ``DeprecationWarning`` is issued and the bool is wrapped
    into a ``SaveLoadConfig`` whose ``keep_name_table`` is set accordingly.
    Calls that do not use the deprecated argument pass through unchanged.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        def __warn_and_build_configs__(keep_name_table):
            warnings.warn(
                "The argument `keep_name_table` has been deprecated, please use `SaveLoadConfig.keep_name_table`.",
                DeprecationWarning)
            configs = SaveLoadConfig()
            configs.keep_name_table = keep_name_table
            return configs

        # deal with deprecated positional form: the second positional
        # argument used to be the bool `keep_name_table` flag
        if len(args) > 1 and isinstance(args[1], bool):
            args = list(args)
            args[1] = __warn_and_build_configs__(args[1])
        # deal with deprecated keyword form: translate and drop the kwarg
        elif 'keep_name_table' in kwargs:
            kwargs['configs'] = __warn_and_build_configs__(
                kwargs.pop('keep_name_table'))

        return func(*args, **kwargs)

    return wrapper
@dygraph_only @dygraph_only
def save_dygraph(state_dict, model_path): def save_dygraph(state_dict, model_path):
''' '''
...@@ -100,17 +134,27 @@ def save_dygraph(state_dict, model_path): ...@@ -100,17 +134,27 @@ def save_dygraph(state_dict, model_path):
# TODO(qingqing01): remove dygraph_only to support loading static model. # TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready. # maybe need to unify the loading interface after 2.0 API is ready.
#@dygraph_only # @dygraph_only
def load_dygraph(model_path, keep_name_table=False): @deprecate_keep_name_table
def load_dygraph(model_path, configs=None):
''' '''
:api_attr: imperative :api_attr: imperative
Load parameter state_dict from disk. Load parameter state dict from disk.
.. note::
Due to some historical reasons, if you load ``state_dict`` from the saved
result of `paddle.io.save_inference_model`, the structured variable name
will cannot be restored. You need to set the argument `use_structured_name=False`
when using `Layer.set_state_dict` later.
Args: Args:
model_path(str) : The file prefix store the state_dict. (The path should Not contain suffix '.pdparams') model_path(str) : The file prefix store the state_dict.
keep_name_table(bool, optional) : Whether keep structed name to parameter name conversion table in output dict. (The path should Not contain suffix '.pdparams')
Default : False configs (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
object that specifies additional configuration options, these options
are for compatibility with ``jit.save/io.save_inference_model`` formats.
Default None.
Returns: Returns:
state_dict(dict) : the dict store the state_dict state_dict(dict) : the dict store the state_dict
...@@ -118,23 +162,27 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -118,23 +162,27 @@ def load_dygraph(model_path, keep_name_table=False):
Examples: Examples:
.. code-block:: python .. code-block:: python
import paddle.fluid as fluid import paddle
with fluid.dygraph.guard(): paddle.disable_static()
emb = fluid.dygraph.Embedding([10, 10])
state_dict = emb.state_dict() emb = paddle.nn.Embedding([10, 10])
fluid.save_dygraph( state_dict, "paddle_dy")
adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000), state_dict = emb.state_dict()
parameter_list = emb.parameters() ) paddle.save(state_dict, "paddle_dy")
state_dict = adam.state_dict()
fluid.save_dygraph( state_dict, "paddle_dy")
para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy") scheduler = paddle.optimizer.lr_scheduler.NoamLR(
d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam(
learning_rate=scheduler,
parameters=emb.parameters())
state_dict = adam.state_dict()
paddle.save(state_dict, "paddle_dy")
''' para_state_dict, opti_state_dict = paddle.load("paddle_dy")
'''
# deal with argument `model_path`
model_prefix = model_path model_prefix = model_path
if model_prefix.endswith(".pdparams"): if model_prefix.endswith(".pdparams"):
model_prefix = model_prefix[:-9] model_prefix = model_prefix[:-9]
...@@ -145,66 +193,44 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -145,66 +193,44 @@ def load_dygraph(model_path, keep_name_table=False):
opti_dict = None opti_dict = None
params_file_path = model_prefix + ".pdparams" params_file_path = model_prefix + ".pdparams"
opti_file_path = model_prefix + ".pdopt" opti_file_path = model_prefix + ".pdopt"
# deal with argument `configs`
if configs is None:
configs = SaveLoadConfig()
if not os.path.exists(params_file_path) and not os.path.exists( if not os.path.exists(params_file_path) and not os.path.exists(
opti_file_path): opti_file_path):
# Load state dict by `jit.save` save format # Load state dict by `jit.save/io.save_inference_model` save format
# TODO(chenweihang): [Why not support `io.save_infernece_model` save format here] # NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
# The model saved by `save_inference_model` does not completely correspond to # The model saved by `save_inference_model` does not completely correspond to
# the information required by the `state_dict` under the dygraph. # the information required by the `state_dict` under the dygraph.
# Although we reluctantly restore the `state_dict` in some scenarios, # `save_inference_model` not save structured name, we need to remind
# this may not be complete and there are some limitations, so this function # the user to configure the `use_structured_name` argument when `set_state_dict`
# will be considered later. The limitations include: # NOTE(chenweihang): `jit.save` doesn't save optimizer state
# 1. `save_inference_model` not save structured name, we need to remind
# the user to configure the `use_structured_name` argument when `set_dict`,
# but this argument is currently not public
# 2. if `save_inference_model` save all persistable variables in a single file,
# user need to give the variable name list to load `state_dict`
# 1. check model path # 1. check model path
if not os.path.isdir(model_prefix): if not os.path.isdir(model_prefix):
raise ValueError("Model saved directory '%s' is not exists." % raise ValueError("Model saved directory '%s' is not exists." %
model_prefix) model_prefix)
# 2. load `__variables.info__`
var_info_path = os.path.join(model_prefix, EXTRA_VAR_INFO_FILENAME) # 2. load program desc & construct _ProgramHolder
if not os.path.exists(var_info_path): programs = _construct_program_holders(model_path,
raise RuntimeError( configs.model_filename)
"No target can be loaded. Now only supports loading `state_dict` from "
"the result saved by `imperative.save` and `imperative.jit.save`." # 3. load layer parameters & buffers
) # NOTE: using fluid.dygraph.guard() here will cause import error in py2
with open(var_info_path, 'rb') as f: with guard():
extra_var_info = pickle.load(f) persistable_var_dict = _construct_params_and_buffers(
# 3. load `__variables__` model_prefix,
# TODO(chenweihang): now only supports loading from default save format: programs,
# - all persistable vars saved in one file named `__variables__` configs.separate_params,
# for other case, we may need to modify the arguments of this API configs.params_filename,
var_file_path = os.path.join(model_prefix, VARIABLE_FILENAME) append_suffix=False)
if not os.path.exists(var_file_path):
raise RuntimeError( # 4. construct state_dict
"The parameter file to be loaded was not found. " para_dict = dict()
"Now only supports loading from the default save format, " for var_name in persistable_var_dict:
"and does not support custom params_filename and " para_dict[var_name] = persistable_var_dict[var_name].numpy()
"save parameters separately.")
# 4. load all persistable vars
load_var_list = []
for name in sorted(extra_var_info):
var = _varbase_creator(name=name, persistable=True)
load_var_list.append(var)
_dygraph_tracer().trace_op(
type='load_combine',
inputs={},
outputs={'Out': load_var_list},
attrs={'file_path': var_file_path})
# 5. construct state_dict
para_dict = dict()
for var in load_var_list:
structured_name = extra_var_info[var.name].get('structured_name',
None)
if structured_name is None:
raise RuntimeError(
"Cannot find saved variable (%s)'s structured name in saved model.",
var.name)
para_dict[structured_name] = var.numpy()
# NOTE: `jit.save` doesn't save optimizer state
else: else:
# Load state dict by `save_dygraph` save format # Load state dict by `save_dygraph` save format
para_dict = {} para_dict = {}
...@@ -213,7 +239,7 @@ def load_dygraph(model_path, keep_name_table=False): ...@@ -213,7 +239,7 @@ def load_dygraph(model_path, keep_name_table=False):
para_dict = pickle.load(f) if six.PY2 else pickle.load( para_dict = pickle.load(f) if six.PY2 else pickle.load(
f, encoding='latin1') f, encoding='latin1')
if not keep_name_table and "StructuredToParameterName@@" in para_dict: if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict:
del para_dict["StructuredToParameterName@@"] del para_dict["StructuredToParameterName@@"]
if os.path.exists(opti_file_path): if os.path.exists(opti_file_path):
......
...@@ -488,6 +488,15 @@ def _load_persistable_vars(model_path, ...@@ -488,6 +488,15 @@ def _load_persistable_vars(model_path,
return load_var_dict return load_var_dict
# NOTE(chenweihang): to adapt paddle.load to get state_dict
def _remove_varname_suffix(var_dict, program_holder):
no_suffix_var_dict = dict()
for var_name in var_dict:
no_suffix_name = program_holder._suffix_varname_dict[var_name]
no_suffix_var_dict[no_suffix_name] = var_dict[var_name]
return no_suffix_var_dict
def _construct_program_holders(model_path, model_filename=None): def _construct_program_holders(model_path, model_filename=None):
# make sure the path has been checked # make sure the path has been checked
program_holder_dict = dict() program_holder_dict = dict()
...@@ -517,7 +526,8 @@ def _construct_program_holders(model_path, model_filename=None): ...@@ -517,7 +526,8 @@ def _construct_program_holders(model_path, model_filename=None):
def _construct_params_and_buffers(model_path, def _construct_params_and_buffers(model_path,
programs, programs,
separate_params=False, separate_params=False,
params_filename=None): params_filename=None,
append_suffix=True):
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME) var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
if os.path.exists(var_info_path): if os.path.exists(var_info_path):
var_dict = _load_persistable_vars(model_path, var_info_path, var_dict = _load_persistable_vars(model_path, var_info_path,
...@@ -526,6 +536,10 @@ def _construct_params_and_buffers(model_path, ...@@ -526,6 +536,10 @@ def _construct_params_and_buffers(model_path,
else: else:
var_dict = _load_persistable_vars_by_program( var_dict = _load_persistable_vars_by_program(
model_path, programs['forward'], params_filename) model_path, programs['forward'], params_filename)
if not append_suffix:
var_dict = _remove_varname_suffix(var_dict, programs['forward'])
return var_dict return var_dict
...@@ -685,7 +699,7 @@ class TranslatedLayer(layers.Layer): ...@@ -685,7 +699,7 @@ class TranslatedLayer(layers.Layer):
# 1. load program desc & construct _ProgramHolder # 1. load program desc & construct _ProgramHolder
programs = _construct_program_holders(model_path, model_filename) programs = _construct_program_holders(model_path, model_filename)
# 2. load layer parameters & parameter attributes # 2. load layer parameters & buffers
persistable_vars = _construct_params_and_buffers( persistable_vars = _construct_params_and_buffers(
model_path, programs, separate_params, params_filename) model_path, programs, separate_params, params_filename)
......
...@@ -293,6 +293,8 @@ class SaveLoadConfig(object): ...@@ -293,6 +293,8 @@ class SaveLoadConfig(object):
self._model_filename = None self._model_filename = None
self._params_filename = None self._params_filename = None
self._separate_params = False self._separate_params = False
# used for `paddle.load`
self._keep_name_table = False
# NOTE: Users rarely use following configs, so these configs are not open to users, # NOTE: Users rarely use following configs, so these configs are not open to users,
# reducing user learning costs, but we retain the configuration capabilities # reducing user learning costs, but we retain the configuration capabilities
...@@ -600,6 +602,54 @@ class SaveLoadConfig(object): ...@@ -600,6 +602,54 @@ class SaveLoadConfig(object):
% type(value)) % type(value))
self._separate_params = value self._separate_params = value
    @property
    def keep_name_table(self):
        """
        Whether to keep the ``structured_name -> parameter_name`` mapping
        (stored under the key ``StructuredToParameterName@@``) in the state
        dict returned by ``paddle.load``.

        This mapping is debugging information saved by ``paddle.save``; it
        does not affect training or inference, so by default it is stripped
        from the ``paddle.load`` result. Default: False.

        .. note::
            Only used for ``paddle.load``.

        Examples:
            .. code-block:: python

                import paddle

                paddle.disable_static()

                linear = paddle.nn.Linear(5, 1)

                state_dict = linear.state_dict()
                paddle.save(state_dict, "paddle_dy")

                configs = paddle.SaveLoadConfig()
                configs.keep_name_table = True
                para_state_dict, _ = paddle.load("paddle_dy", configs)

                print(para_state_dict)
                # the name_table is 'StructuredToParameterName@@'
                # {'bias': array([0.], dtype=float32),
                #  'StructuredToParameterName@@':
                #     {'bias': u'linear_0.b_0', 'weight': u'linear_0.w_0'},
                #  'weight': array([[ 0.04230034],
                #     [-0.1222527 ],
                #     [ 0.7392676 ],
                #     [-0.8136974 ],
                #     [ 0.01211023]], dtype=float32)}
        """
        return self._keep_name_table

    @keep_name_table.setter
    def keep_name_table(self, value):
        # only a plain bool flag is accepted; reject anything else early
        if not isinstance(value, bool):
            raise TypeError(
                "The SaveLoadConfig.keep_name_table should be bool value, but received input's type is %s."
                % type(value))
        self._keep_name_table = value
@switch_to_static_graph @switch_to_static_graph
def save(layer, model_path, input_spec=None, configs=None): def save(layer, model_path, input_spec=None, configs=None):
......
...@@ -43,7 +43,7 @@ class TestDirectory(unittest.TestCase): ...@@ -43,7 +43,7 @@ class TestDirectory(unittest.TestCase):
'paddle.distributed.prepare_context', 'paddle.DataParallel', 'paddle.distributed.prepare_context', 'paddle.DataParallel',
'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static', 'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer', 'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
'paddle.jit.save', 'paddle.jit.load', 'paddle.jit.SaveLoadConfig', 'paddle.jit.save', 'paddle.jit.load', 'paddle.SaveLoadConfig',
'paddle.NoamDecay', 'paddle.PiecewiseDecay', 'paddle.NoamDecay', 'paddle.PiecewiseDecay',
'paddle.NaturalExpDecay', 'paddle.ExponentialDecay', 'paddle.NaturalExpDecay', 'paddle.ExponentialDecay',
'paddle.InverseTimeDecay', 'paddle.PolynomialDecay', 'paddle.InverseTimeDecay', 'paddle.PolynomialDecay',
......
...@@ -912,6 +912,22 @@ class TestDygraphPtbRnn(unittest.TestCase): ...@@ -912,6 +912,22 @@ class TestDygraphPtbRnn(unittest.TestCase):
para_state_dict, opti_state_dict = paddle.load( para_state_dict, opti_state_dict = paddle.load(
os.path.join('saved_dy', 'emb_dy.pdopt')) os.path.join('saved_dy', 'emb_dy.pdopt'))
def test_load_compatible_with_keep_name_table(self):
with fluid.dygraph.guard():
emb = fluid.dygraph.Embedding([10, 10])
state_dict = emb.state_dict()
paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy'))
para_state_dict, opti_state_dict = paddle.load(
os.path.join('saved_dy', 'emb_dy'), True)
self.assertTrue(para_state_dict != None)
self.assertTrue(opti_state_dict == None)
para_state_dict, opti_state_dict = paddle.load(
os.path.join('saved_dy', 'emb_dy'), keep_name_table=True)
self.assertTrue(para_state_dict != None)
self.assertTrue(opti_state_dict == None)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -183,25 +183,6 @@ class TestJitSaveLoad(unittest.TestCase): ...@@ -183,25 +183,6 @@ class TestJitSaveLoad(unittest.TestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path) model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_load_dygraph_no_var_info(self):
model_path = "model.test_jit_save_load.no_var_info"
self.train_and_save_model(model_path=model_path)
# remove `__variables.info__`
var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME)
os.remove(var_info_path)
new_layer = LinearNet(784, 1)
with self.assertRaises(RuntimeError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
def test_load_dygraph_not_var_file(self):
model_path = "model.test_jit_save_load.no_var_file"
configs = fluid.dygraph.jit.SaveLoadConfig()
configs.params_filename = "__params__"
self.train_and_save_model(model_path=model_path, configs=configs)
new_layer = LinearNet(784, 1)
with self.assertRaises(RuntimeError):
model_dict, _ = fluid.dygraph.load_dygraph(model_path)
class LinearNetMultiInput(fluid.dygraph.Layer): class LinearNetMultiInput(fluid.dygraph.Layer):
def __init__(self, in_size, out_size): def __init__(self, in_size, out_size):
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import six
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from test_imperative_base import new_program_scope
def convolutional_neural_network(img):
    """Build a small LeNet-style CNN over `img` and return the softmax
    prediction variable (10 classes)."""
    # conv (5x5, 20 filters) + 2x2 max-pool, ReLU
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        act="relu")
    conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
    # conv (5x5, 50 filters) + 2x2 max-pool, ReLU
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        act="relu")
    # fully-connected classification head
    prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
    return prediction
def static_train_net(img, label):
    """Build the static training graph: CNN forward, cross-entropy loss,
    and an SGD minimization step. Returns (prediction, mean loss)."""
    pred = convolutional_neural_network(img)

    ce_loss = fluid.layers.cross_entropy(input=pred, label=label)
    mean_loss = fluid.layers.mean(ce_loss)

    fluid.optimizer.SGD(learning_rate=0.001).minimize(mean_loss)

    return pred, mean_loss
class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
    """Checks that `paddle.load` can restore a state dict from results
    saved by `fluid.io.save_inference_model`, under the four combinations
    of default/custom model filename and params filename."""

    def setUp(self):
        self.seed = 90
        self.epoch_num = 1
        self.batch_size = 128
        self.batch_num = 10

    def train_and_save_model(self):
        """Train the static CNN briefly, save it via `save_inference_model`
        with the current filename settings, and return a dict of the trained
        parameter values (name -> numpy array) for later comparison."""
        with new_program_scope():
            startup_program = fluid.default_startup_program()
            main_program = fluid.default_main_program()

            img = fluid.data(
                name='img', shape=[None, 1, 28, 28], dtype='float32')
            label = fluid.data(name='label', shape=[None, 1], dtype='int64')
            prediction, avg_loss = static_train_net(img, label)

            place = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
            ) else fluid.CPUPlace()

            exe = fluid.Executor(place)

            feeder = fluid.DataFeeder(feed_list=[img, label], place=place)
            exe.run(startup_program)

            train_reader = paddle.batch(
                paddle.reader.shuffle(
                    paddle.dataset.mnist.train(), buf_size=100),
                batch_size=self.batch_size)

            # train only a few batches — enough to give parameters
            # non-initial values worth round-tripping
            for _ in range(0, self.epoch_num):
                for batch_id, data in enumerate(train_reader()):
                    exe.run(main_program,
                            feed=feeder.feed(data),
                            fetch_list=[avg_loss])

                    if batch_id > self.batch_num:
                        break

            # snapshot trained parameter values before saving
            static_param_dict = {}
            for param in fluid.default_main_program().all_parameters():
                static_param_dict[param.name] = fluid.executor._fetch_var(
                    param.name)

            fluid.io.save_inference_model(
                self.save_dirname, ["img"], [prediction],
                exe,
                model_filename=self.model_filename,
                params_filename=self.params_filename)

        return static_param_dict

    def check_load_state_dict(self, orig_dict, load_dict):
        """Assert every original parameter value round-tripped exactly."""
        for var_name, value in six.iteritems(orig_dict):
            self.assertTrue(np.array_equal(value, load_dict[var_name]))

    def _train_save_then_check_load(self):
        """Shared body of the four tests: train+save with the current
        filename settings, build a matching SaveLoadConfig, load via
        `paddle.load`, and compare against the trained parameters."""
        orig_param_dict = self.train_and_save_model()

        configs = paddle.SaveLoadConfig()
        if self.params_filename is None:
            # no combined params file -> params were saved separately
            configs.separate_params = True
        else:
            configs.params_filename = self.params_filename
        if self.model_filename is not None:
            configs.model_filename = self.model_filename

        load_param_dict, _ = paddle.load(self.save_dirname, configs)
        self.check_load_state_dict(orig_param_dict, load_param_dict)

    def test_load_default(self):
        self.save_dirname = "static_mnist.load_state_dict.default"
        self.model_filename = None
        self.params_filename = None
        self._train_save_then_check_load()

    def test_load_with_model_filename(self):
        self.save_dirname = "static_mnist.load_state_dict.model_filename"
        self.model_filename = "static_mnist.model"
        self.params_filename = None
        self._train_save_then_check_load()

    def test_load_with_param_filename(self):
        self.save_dirname = "static_mnist.load_state_dict.param_filename"
        self.model_filename = None
        self.params_filename = "static_mnist.params"
        self._train_save_then_check_load()

    def test_load_with_model_and_param_filename(self):
        self.save_dirname = "static_mnist.load_state_dict.model_and_param_filename"
        self.model_filename = "static_mnist.model"
        self.params_filename = "static_mnist.params"
        self._train_save_then_check_load()
if __name__ == '__main__':
unittest.main()
...@@ -20,8 +20,8 @@ __all__ = [ ...@@ -20,8 +20,8 @@ __all__ = [
] ]
__all__ += [ __all__ += [
'grad', 'LayerList', 'load', 'save', 'to_variable', 'no_grad', 'grad', 'LayerList', 'load', 'save', 'SaveLoadConfig', 'to_variable',
'DataParallel' 'no_grad', 'DataParallel'
] ]
__all__ += [ __all__ += [
...@@ -50,6 +50,7 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS ...@@ -50,6 +50,7 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS
from ..fluid.dygraph.base import grad #DEFINE_ALIAS from ..fluid.dygraph.base import grad #DEFINE_ALIAS
from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS
from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS
from ..fluid.dygraph.jit import SaveLoadConfig #DEFINE_ALIAS
from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS
from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
from ..fluid.dygraph.jit import save #DEFINE_ALIAS from ..fluid.dygraph.jit import save #DEFINE_ALIAS
from ..fluid.dygraph.jit import load #DEFINE_ALIAS from ..fluid.dygraph.jit import load #DEFINE_ALIAS
from ..fluid.dygraph.jit import SaveLoadConfig #DEFINE_ALIAS
from ..fluid.dygraph.jit import TracedLayer #DEFINE_ALIAS from ..fluid.dygraph.jit import TracedLayer #DEFINE_ALIAS
from ..fluid.dygraph.jit import set_code_level #DEFINE_ALIAS from ..fluid.dygraph.jit import set_code_level #DEFINE_ALIAS
from ..fluid.dygraph.jit import set_verbosity #DEFINE_ALIAS from ..fluid.dygraph.jit import set_verbosity #DEFINE_ALIAS
...@@ -23,6 +22,6 @@ from ..fluid.dygraph import ProgramTranslator #DEFINE_ALIAS ...@@ -23,6 +22,6 @@ from ..fluid.dygraph import ProgramTranslator #DEFINE_ALIAS
from ..fluid.dygraph.io import TranslatedLayer #DEFINE_ALIAS from ..fluid.dygraph.io import TranslatedLayer #DEFINE_ALIAS
__all__ = [ __all__ = [
'save', 'load', 'SaveLoadConfig', 'TracedLayer', 'to_static', 'save', 'load', 'TracedLayer', 'to_static', 'ProgramTranslator',
'ProgramTranslator', 'TranslatedLayer', 'set_code_level', 'set_verbosity' 'TranslatedLayer', 'set_code_level', 'set_verbosity'
] ]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册