Commit 24a33bed authored by Chen Weihang

replace config by kwargs

Parent 6b727e08
@@ -234,7 +234,6 @@ from .framework import grad #DEFINE_ALIAS
 from .framework import no_grad #DEFINE_ALIAS
 from .framework import save #DEFINE_ALIAS
 from .framework import load #DEFINE_ALIAS
-from .framework import SaveLoadConfig #DEFINE_ALIAS
 from .framework import DataParallel #DEFINE_ALIAS
 from .framework import NoamDecay #DEFINE_ALIAS
...
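With this, the `paddle.SaveLoadConfig` alias is gone from the top-level namespace. A minimal sketch of the user-visible effect (assuming a build containing this commit):

import paddle

# The config class is no longer exported; only the private helper
# paddle.fluid.dygraph.jit._SaveLoadConfig remains for internal use.
print(hasattr(paddle, 'SaveLoadConfig'))  # False after this commit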
@@ -24,7 +24,7 @@ from . import learning_rate_scheduler
 import warnings
 from .. import core
 from .base import guard
-from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
+from paddle.fluid.dygraph.jit import _SaveLoadConfig
 from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME

 __all__ = [
@@ -33,35 +33,27 @@ __all__ = [
 ]

-# NOTE(chenweihang): deprecate load_dygraph's argument keep_name_table,
-# ensure compatibility when user still use keep_name_table argument
-def deprecate_keep_name_table(func):
-    @functools.wraps(func)
-    def wrapper(*args, **kwargs):
-        def __warn_and_build_configs__(keep_name_table):
-            warnings.warn(
-                "The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
-                DeprecationWarning)
-            config = SaveLoadConfig()
-            config.keep_name_table = keep_name_table
-            return config
-
-        # deal with arg `keep_name_table`
-        if len(args) > 1 and isinstance(args[1], bool):
-            args = list(args)
-            args[1] = __warn_and_build_configs__(args[1])
-        # deal with kwargs
-        elif 'keep_name_table' in kwargs:
-            kwargs['config'] = __warn_and_build_configs__(kwargs[
-                'keep_name_table'])
-            kwargs.pop('keep_name_table')
-        else:
-            # do nothing
-            pass
-
-        return func(*args, **kwargs)
-
-    return wrapper
+def _parse_load_config(configs):
+    supported_configs = [
+        'model_filename', 'params_filename', 'separate_params',
+        'keep_name_table'
+    ]
+
+    # input check
+    for key in configs:
+        if key not in supported_configs:
+            raise ValueError(
+                "The additional config (%s) of `paddle.fluid.load_dygraph` is not supported."
+                % (key))
+
+    # construct inner config
+    inner_config = _SaveLoadConfig()
+    inner_config.model_filename = configs.get('model_filename', None)
+    inner_config.params_filename = configs.get('params_filename', None)
+    inner_config.separate_params = configs.get('separate_params', None)
+    inner_config.keep_name_table = configs.get('keep_name_table', None)
+
+    return inner_config

 @dygraph_only
@@ -135,9 +127,7 @@ def save_dygraph(state_dict, model_path):
 # TODO(qingqing01): remove dygraph_only to support loading static model.
 # maybe need to unify the loading interface after 2.0 API is ready.
 # @dygraph_only
-@deprecate_save_load_configs
-@deprecate_keep_name_table
-def load_dygraph(model_path, config=None):
+def load_dygraph(model_path, **configs):
     '''
     :api_attr: imperative
@@ -152,10 +142,20 @@ def load_dygraph(model_path, config=None):
     Args:
         model_path(str) : The file prefix store the state_dict.
                           (The path should Not contain suffix '.pdparams')
-        config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
-            object that specifies additional configuration options, these options
-            are for compatibility with ``jit.save/io.save_inference_model`` formats.
-            Default None.
+        configs (dict, optional): Other load configuration options for compatibility.
+            We do not recommend using these configurations; if they are not
+            necessary, DO NOT use them. Default None.
+            The following options are currently supported:
+            (1) model_filename (str): The filename to load the translated program
+            of the target Layer. The default filename is :code:`__model__` .
+            (2) params_filename (str): The filename to load all persistable
+            variables of the target Layer. The default filename is :code:`__variables__` .
+            (3) separate_params (bool): Whether to load the Layer parameters from
+            separate files. If True, each parameter is loaded from its own file,
+            named after the parameter, and the params_filename configuration does
+            not take effect. Default False.
+            (4) keep_name_table (bool): Whether to keep the ``structured_name ->
+            parameter_name`` dict in the loaded state dict. This dict is debugging
+            information saved when calling ``paddle.fluid.save_dygraph`` . It is
+            generally only used for debugging and does not affect training or
+            inference. By default it is not retained in the ``paddle.fluid.load_dygraph``
+            result. Default: False.
     Returns:
         state_dict(dict) : the dict store the state_dict
@@ -196,8 +196,7 @@ def load_dygraph(model_path, config=None):
     opti_file_path = model_prefix + ".pdopt"

     # deal with argument `config`
-    if config is None:
-        config = SaveLoadConfig()
+    config = _parse_load_config(configs)

     if os.path.exists(params_file_path) or os.path.exists(opti_file_path):
         # Load state dict by `save_dygraph` save format
...
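Taken together, these changes replace load_dygraph's config-object parameter with plain keyword arguments that are validated and folded into a private _SaveLoadConfig. A hedged sketch of the call-site migration ('mnist' is an assumed save prefix, not taken from this diff):

import paddle.fluid as fluid

# Before this commit:
#   config = fluid.dygraph.jit.SaveLoadConfig()
#   config.keep_name_table = True
#   para_dict, opti_dict = fluid.load_dygraph('mnist', config)

# After this commit, the same option is a plain keyword argument:
para_dict, opti_dict = fluid.load_dygraph('mnist', keep_name_table=True)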
This diff is collapsed.
@@ -14,7 +14,7 @@

 from __future__ import print_function

-from paddle.fluid.dygraph.jit import SaveLoadConfig
+from paddle.fluid.dygraph.jit import _SaveLoadConfig
 from paddle.fluid.dygraph.io import TranslatedLayer
@@ -31,7 +31,7 @@ class StaticModelRunner(object):
     """

     def __new__(cls, model_dir, model_filename=None, params_filename=None):
-        configs = SaveLoadConfig()
+        configs = _SaveLoadConfig()
         if model_filename is not None:
             configs.model_filename = model_filename
         if params_filename is not None:
...
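StaticModelRunner keeps its public signature and now just builds the private config internally, so its callers need no changes. A usage sketch under the assumption that the class lives at paddle.fluid.dygraph.static_runner (directory and filenames are placeholders):

from paddle.fluid.dygraph.static_runner import StaticModelRunner

# both filenames are optional; the defaults described above apply
runner = StaticModelRunner(
    './static_model_dir',
    model_filename='model',
    params_filename='params')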
@@ -498,13 +498,11 @@ def do_train(args, to_static):
             step += 1

     # save inference model
     if to_static:
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.output_spec = [crf_decode]
         fluid.dygraph.jit.save(
             layer=model,
             model_path=args.model_save_dir,
             input_spec=[words, length],
-            configs=configs)
+            output_spec=[crf_decode])
     else:
         fluid.dygraph.save_dygraph(model.state_dict(), args.dy_param_path)
...
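The same mechanical rewrite recurs through the test updates that follow: delete the two SaveLoadConfig lines and pass output_spec directly to jit.save. A condensed before/after sketch (layer, x, and pred stand in for the objects in each test):

import paddle.fluid as fluid

# Before:
#   configs = fluid.dygraph.jit.SaveLoadConfig()
#   configs.output_spec = [pred]
#   fluid.dygraph.jit.save(layer=layer, model_path='./infer_model',
#                          input_spec=[x], configs=configs)

# After:
fluid.dygraph.jit.save(
    layer=layer,
    model_path='./infer_model',
    input_spec=[x],
    output_spec=[pred])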
@@ -218,13 +218,11 @@ class TestMNISTWithToStatic(TestMNIST):
     def check_jit_save_load(self, model, inputs, input_spec, to_static, gt_out):
         if to_static:
             infer_model_path = "./test_mnist_inference_model_by_jit_save"
-            configs = fluid.dygraph.jit.SaveLoadConfig()
-            configs.output_spec = [gt_out]
             fluid.dygraph.jit.save(
                 layer=model,
                 model_path=infer_model_path,
                 input_spec=input_spec,
-                configs=configs)
+                output_spec=[gt_out])
             # load in static mode
             static_infer_out = self.jit_load_and_run_inference_static(
                 infer_model_path, inputs)
...
@@ -67,13 +67,11 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
         layer.clear_gradients()
         # test for saving model in dygraph.guard
         infer_model_dir = "./test_dy2stat_save_inference_model_in_guard"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.output_spec = [pred]
         fluid.dygraph.jit.save(
             layer=layer,
             model_path=infer_model_dir,
             input_spec=[x],
-            configs=configs)
+            output_spec=[pred])
         # Check the correctness of the inference
         dygraph_out, _ = layer(x)
         self.check_save_inference_model(layer, [x_data], dygraph_out.numpy())
@@ -92,15 +90,12 @@ class TestDyToStaticSaveInferenceModel(unittest.TestCase):
         expected_persistable_vars = set([p.name for p in model.parameters()])

         infer_model_dir = "./test_dy2stat_save_inference_model"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        if fetch is not None:
-            configs.output_spec = fetch
-        configs.separate_params = True
         fluid.dygraph.jit.save(
             layer=model,
             model_path=infer_model_dir,
             input_spec=feed if feed else None,
-            configs=configs)
+            separate_params=True,
+            output_spec=fetch if fetch else None)
         saved_var_names = set([
             filename for filename in os.listdir(infer_model_dir)
             if filename != '__model__' and filename != EXTRA_VAR_INFO_FILENAME
...
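As this test now shows, separate_params is likewise a plain keyword. A short sketch of what it controls (the listed filenames are illustrative; the test itself only compares them against the layer's parameter names):

import os
import paddle.fluid as fluid

fluid.dygraph.jit.save(
    layer=model,
    model_path='./infer_model',
    input_spec=[x],
    separate_params=True)

# With separate_params=True each parameter is stored in its own file named
# after the parameter, instead of one combined '__variables__' file.
print(os.listdir('./infer_model'))  # e.g. ['__model__', 'linear_0.w_0', 'linear_0.b_0']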
@@ -383,10 +383,10 @@ def train(train_reader, to_static):
             step_idx += 1
             if step_idx == STEP_NUM:
                 if to_static:
-                    configs = fluid.dygraph.jit.SaveLoadConfig()
-                    configs.output_spec = [pred]
-                    fluid.dygraph.jit.save(se_resnext, MODEL_SAVE_PATH,
-                                           [img], configs)
+                    fluid.dygraph.jit.save(
+                        se_resnext,
+                        MODEL_SAVE_PATH, [img],
+                        output_spec=[pred])
                 else:
                     fluid.dygraph.save_dygraph(se_resnext.state_dict(),
                                                DY_STATE_DICT_SAVE_PATH)
...
@@ -43,15 +43,14 @@ class TestDirectory(unittest.TestCase):
             'paddle.distributed.prepare_context', 'paddle.DataParallel',
             'paddle.jit', 'paddle.jit.TracedLayer', 'paddle.jit.to_static',
             'paddle.jit.ProgramTranslator', 'paddle.jit.TranslatedLayer',
-            'paddle.jit.save', 'paddle.jit.load', 'paddle.SaveLoadConfig',
-            'paddle.NoamDecay', 'paddle.PiecewiseDecay',
-            'paddle.NaturalExpDecay', 'paddle.ExponentialDecay',
-            'paddle.InverseTimeDecay', 'paddle.PolynomialDecay',
-            'paddle.CosineDecay', 'paddle.static.Executor',
-            'paddle.static.global_scope', 'paddle.static.scope_guard',
-            'paddle.static.append_backward', 'paddle.static.gradients',
-            'paddle.static.BuildStrategy', 'paddle.static.CompiledProgram',
-            'paddle.static.ExecutionStrategy',
+            'paddle.jit.save', 'paddle.jit.load', 'paddle.NoamDecay',
+            'paddle.PiecewiseDecay', 'paddle.NaturalExpDecay',
+            'paddle.ExponentialDecay', 'paddle.InverseTimeDecay',
+            'paddle.PolynomialDecay', 'paddle.CosineDecay',
+            'paddle.static.Executor', 'paddle.static.global_scope',
+            'paddle.static.scope_guard', 'paddle.static.append_backward',
+            'paddle.static.gradients', 'paddle.static.BuildStrategy',
+            'paddle.static.CompiledProgram', 'paddle.static.ExecutionStrategy',
             'paddle.static.default_main_program',
             'paddle.static.default_startup_program', 'paddle.static.Program',
             'paddle.static.name_scope', 'paddle.static.program_guard',
@@ -104,9 +103,7 @@ class TestDirectory(unittest.TestCase):
             'paddle.imperative.TracedLayer', 'paddle.imperative.declarative',
             'paddle.imperative.ProgramTranslator',
             'paddle.imperative.TranslatedLayer', 'paddle.imperative.jit.save',
-            'paddle.imperative.jit.load',
-            'paddle.imperative.jit.SaveLoadConfig',
-            'paddle.imperative.NoamDecay'
+            'paddle.imperative.jit.load', 'paddle.imperative.NoamDecay'
             'paddle.imperative.PiecewiseDecay',
             'paddle.imperative.NaturalExpDecay',
             'paddle.imperative.ExponentialDecay',
...
@@ -225,16 +225,13 @@ class TestJitSaveLoad(unittest.TestCase):
         paddle.manual_seed(SEED)
         paddle.framework.random._manual_program_seed(SEED)

-    def train_and_save_model(self, model_path=None, configs=None):
+    def train_and_save_model(self, model_path=None):
         layer = LinearNet(784, 1)
         example_inputs, layer, _ = train(layer)
         final_model_path = model_path if model_path else self.model_path
         orig_input_types = [type(x) for x in example_inputs]
         fluid.dygraph.jit.save(
-            layer=layer,
-            model_path=final_model_path,
-            input_spec=example_inputs,
-            configs=configs)
+            layer=layer, model_path=final_model_path, input_spec=example_inputs)
         new_input_types = [type(x) for x in example_inputs]
         self.assertEqual(orig_input_types, new_input_types)
         return layer
@@ -314,7 +311,6 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
                                   [None, 8], name='x')])

         model_path = "model.input_spec.output_spec"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
         # check inputs and outputs
         self.assertTrue(len(net.forward.inputs) == 1)
         input_x = net.forward.inputs[0]
@@ -322,11 +318,11 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
         self.assertTrue(input_x.name == 'x')

         # 1. prune loss
-        configs.output_spec = net.forward.outputs[:1]
-        fluid.dygraph.jit.save(net, model_path, configs=configs)
+        output_spec = net.forward.outputs[:1]
+        fluid.dygraph.jit.save(net, model_path, output_spec=output_spec)

         # 2. load to infer
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = fluid.dygraph.jit.load(model_path)
         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
         pred = infer_layer(x)
@@ -335,7 +331,6 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
         net = LinearNetMultiInput(8, 8)

         model_path = "model.multi_inout.output_spec1"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
         # 1. check inputs and outputs
         self.assertTrue(len(net.forward.inputs) == 2)
         input_x = net.forward.inputs[0]
@@ -344,11 +339,11 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):
         self.assertTrue(input_y.shape == (-1, 8))

         # 2. prune loss
-        configs.output_spec = net.forward.outputs[:2]
-        fluid.dygraph.jit.save(net, model_path, configs=configs)
+        output_spec = net.forward.outputs[:2]
+        fluid.dygraph.jit.save(net, model_path, output_spec=output_spec)

         # 3. load to infer
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = fluid.dygraph.jit.load(model_path)
         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
         y = fluid.dygraph.to_variable(
@@ -358,10 +353,11 @@ class TestSaveLoadWithInputSpec(unittest.TestCase):

         # 1. prune y and loss
         model_path = "model.multi_inout.output_spec2"
-        configs.output_spec = net.forward.outputs[:1]
-        fluid.dygraph.jit.save(net, model_path, [input_x], configs)
+        output_spec = net.forward.outputs[:1]
+        fluid.dygraph.jit.save(
+            net, model_path, [input_x], output_spec=output_spec)

         # 2. load again
-        infer_layer2 = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer2 = fluid.dygraph.jit.load(model_path)

         # 3. predict
         pred_xx = infer_layer2(x)
@@ -377,16 +373,16 @@ class TestJitSaveLoadConfig(unittest.TestCase):
         paddle.manual_seed(SEED)
         paddle.framework.random._manual_program_seed(SEED)

-    def basic_save_load(self, layer, model_path, configs):
+    def basic_save_load(self, layer, model_path, **configs):
         # 1. train & save
         example_inputs, train_layer, _ = train(layer)
         fluid.dygraph.jit.save(
             layer=train_layer,
             model_path=model_path,
             input_spec=example_inputs,
-            configs=configs)
+            **configs)
         # 2. load
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = fluid.dygraph.jit.load(model_path, **configs)
         train_layer.eval()
         # 3. inference & compare
         x = fluid.dygraph.to_variable(
@@ -397,23 +393,18 @@ class TestJitSaveLoadConfig(unittest.TestCase):
     def test_model_filename(self):
         layer = LinearNet(784, 1)
         model_path = "model.save_load_config.output_spec"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.model_filename = "__simplenet__"
-        self.basic_save_load(layer, model_path, configs)
+        self.basic_save_load(layer, model_path, model_filename="__simplenet__")

     def test_params_filename(self):
         layer = LinearNet(784, 1)
         model_path = "model.save_load_config.params_filename"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.params_filename = "__params__"
-        self.basic_save_load(layer, model_path, configs)
+        self.basic_save_load(layer, model_path, params_filename="__params__")

     def test_separate_params(self):
         layer = LinearNet(784, 1)
         model_path = "model.save_load_config.separate_params"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.separate_params = True
-        self.basic_save_load(layer, model_path, configs)
+        self.basic_save_load(layer, model_path, separate_params=True)

     def test_output_spec(self):
         train_layer = LinearNetReturnLoss(8, 8)
@@ -428,16 +419,15 @@ class TestJitSaveLoadConfig(unittest.TestCase):
             train_layer.clear_gradients()

         model_path = "model.save_load_config.output_spec"
-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.output_spec = [out]
+        output_spec = [out]
         fluid.dygraph.jit.save(
             layer=train_layer,
             model_path=model_path,
             input_spec=[x],
-            configs=configs)
+            output_spec=output_spec)

         train_layer.eval()
-        infer_layer = fluid.dygraph.jit.load(model_path, configs=configs)
+        infer_layer = fluid.dygraph.jit.load(model_path)
         x = fluid.dygraph.to_variable(
             np.random.random((4, 8)).astype('float32'))
         self.assertTrue(
@@ -494,13 +484,12 @@ class TestJitPruneModelAndLoad(unittest.TestCase):
             adam.minimize(loss)
             train_layer.clear_gradients()

-        configs = fluid.dygraph.jit.SaveLoadConfig()
-        configs.output_spec = [hidden]
+        output_spec = [hidden]
         fluid.dygraph.jit.save(
             layer=train_layer,
             model_path=self.model_path,
             input_spec=[x],
-            configs=configs)
+            output_spec=output_spec)

         return train_layer
@@ -617,8 +606,6 @@ class TestJitSaveMultiCases(unittest.TestCase):
         out = train_with_label(layer)

         model_path = "test_prune_to_static_after_train"
-        configs = paddle.SaveLoadConfig()
-        configs.output_spec = [out]
         paddle.jit.save(
             layer,
             model_path,
@@ -626,7 +613,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
                 InputSpec(
                     shape=[None, 784], dtype='float32', name="image")
             ],
-            configs=configs)
+            output_spec=[out])

         self.verify_inference_correctness(layer, model_path, True)
@@ -634,10 +621,9 @@ class TestJitSaveMultiCases(unittest.TestCase):
         layer = LinerNetWithLabel(784, 1)

         model_path = "test_prune_to_static_no_train"
-        configs = paddle.SaveLoadConfig()
         # TODO: no train, cannot get output_spec var here
         # now only can use index
-        configs.output_spec = layer.forward.outputs[:1]
+        output_spec = layer.forward.outputs[:1]
         paddle.jit.save(
             layer,
             model_path,
@@ -645,7 +631,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
                 InputSpec(
                     shape=[None, 784], dtype='float32', name="image")
             ],
-            configs=configs)
+            output_spec=output_spec)

         self.verify_inference_correctness(layer, model_path, True)
@@ -676,10 +662,8 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train(layer)

         model_path = "test_not_prune_output_spec_name_warning"
-        configs = paddle.SaveLoadConfig()
         out = paddle.to_tensor(np.random.random((1, 1)).astype('float'))
-        configs.output_spec = [out]
-        paddle.jit.save(layer, model_path, configs=configs)
+        paddle.jit.save(layer, model_path, output_spec=[out])

         self.verify_inference_correctness(layer, model_path)
@@ -708,9 +692,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
         train_with_label(layer)

         model_path = "test_prune_to_static_after_train"
-        configs = paddle.SaveLoadConfig()
         out = paddle.to_tensor(np.random.random((1, 1)).astype('float'))
-        configs.output_spec = [out]
         with self.assertRaises(ValueError):
             paddle.jit.save(
                 layer,
@@ -719,7 +701,7 @@ class TestJitSaveMultiCases(unittest.TestCase):
                     InputSpec(
                         shape=[None, 784], dtype='float32', name="image")
                 ],
-                configs=configs)
+                output_spec=[out])


 class TestJitSaveLoadEmptyLayer(unittest.TestCase):
...
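Note that jit.load no longer takes the pruning configuration at all: output_spec is consumed at save time, and the saved program already reflects the pruned outputs, so these tests load with the path alone. Custom filenames are still forwarded as keywords when they were used at save time, e.g.:

import paddle.fluid as fluid

infer_layer = fluid.dygraph.jit.load('model.save_load_config.output_spec')
# or, when a custom model filename was passed to save:
infer_layer = fluid.dygraph.jit.load(
    'model.save_load_config.output_spec', model_filename='__simplenet__')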
@@ -63,6 +63,8 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.epoch_num = 1
         self.batch_size = 128
         self.batch_num = 10
+        # enable static mode
+        paddle.enable_static()

     def train_and_save_model(self, only_params=False):
         with new_program_scope():
@@ -136,13 +138,12 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.params_filename = None
         orig_param_dict = self.train_and_save_model()

-        config = paddle.SaveLoadConfig()
-        config.separate_params = True
-        config.model_filename = self.model_filename
-        load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config)
+        load_param_dict, _ = fluid.load_dygraph(
+            self.save_dirname, model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, load_param_dict)

-        new_load_param_dict = paddle.load(self.save_dirname, config)
+        new_load_param_dict = paddle.load(
+            self.save_dirname, model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     def test_load_with_param_filename(self):
@@ -151,12 +152,12 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.params_filename = "static_mnist.params"
         orig_param_dict = self.train_and_save_model()

-        config = paddle.SaveLoadConfig()
-        config.params_filename = self.params_filename
-        load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config)
+        load_param_dict, _ = fluid.load_dygraph(
+            self.save_dirname, params_filename=self.params_filename)
         self.check_load_state_dict(orig_param_dict, load_param_dict)

-        new_load_param_dict = paddle.load(self.save_dirname, config)
+        new_load_param_dict = paddle.load(
+            self.save_dirname, params_filename=self.params_filename)
         self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     def test_load_with_model_and_param_filename(self):
@@ -165,13 +166,16 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase):
         self.params_filename = "static_mnist.params"
         orig_param_dict = self.train_and_save_model()

-        config = paddle.SaveLoadConfig()
-        config.params_filename = self.params_filename
-        config.model_filename = self.model_filename
-        load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config)
+        load_param_dict, _ = fluid.load_dygraph(
+            self.save_dirname,
+            params_filename=self.params_filename,
+            model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, load_param_dict)

-        new_load_param_dict = paddle.load(self.save_dirname, config)
+        new_load_param_dict = paddle.load(
+            self.save_dirname,
+            params_filename=self.params_filename,
+            model_filename=self.model_filename)
         self.check_load_state_dict(orig_param_dict, new_load_param_dict)

     def test_load_state_dict_from_save_params(self):
...
@@ -20,8 +20,8 @@ __all__ = [
 ]

 __all__ += [
-    'grad', 'LayerList', 'load', 'save', 'SaveLoadConfig', 'to_variable',
-    'no_grad', 'DataParallel'
+    'grad', 'LayerList', 'load', 'save', 'to_variable', 'no_grad',
+    'DataParallel'
 ]

 __all__ += [
@@ -50,7 +50,6 @@ from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS
 from ..fluid.dygraph.base import grad #DEFINE_ALIAS
 from .io import save
 from .io import load
-from ..fluid.dygraph.jit import SaveLoadConfig #DEFINE_ALIAS
 from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS
 from ..fluid.dygraph.learning_rate_scheduler import NoamDecay #DEFINE_ALIAS
...
@@ -26,6 +26,7 @@ import paddle
 from paddle import fluid
 from paddle.fluid import core
 from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer
+from paddle.fluid.dygraph.jit import _SaveLoadConfig
 from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME

 __all__ = [
@@ -116,6 +117,29 @@ def _load_state_dict_from_save_params(model_path):
     return load_param_dict

+def _parse_load_config(configs):
+    supported_configs = [
+        'model_filename', 'params_filename', 'separate_params',
+        'keep_name_table'
+    ]
+
+    # input check
+    for key in configs:
+        if key not in supported_configs:
+            raise ValueError(
+                "The additional config (%s) of `paddle.load` is not supported."
+                % key)
+
+    # construct inner config
+    inner_config = _SaveLoadConfig()
+    inner_config.model_filename = configs.get('model_filename', None)
+    inner_config.params_filename = configs.get('params_filename', None)
+    inner_config.separate_params = configs.get('separate_params', None)
+    inner_config.keep_name_table = configs.get('keep_name_table', None)
+
+    return inner_config
+
 def save(obj, path):
     '''
     Save an object to the specified path.
@@ -178,7 +202,7 @@ def save(obj, path):
         pickle.dump(saved_obj, f, protocol=2)

-def load(path, config=None):
+def load(path, **configs):
     '''
     Load an object can be used in paddle from specified path.
@@ -197,10 +221,20 @@ def load(path, config=None):
     Args:
         path(str) : The path to load the target object. Generally, the path is the target
             file path, when compatible with loading the saved results of
             ``paddle.jit.save/paddle.static.save_inference_model`` , the path is a directory.
-        config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
-            object that specifies additional configuration options, these options
-            are for compatibility with ``paddle.jit.save/paddle.static.save_inference_model``
-            formats. Default None.
+        configs (dict, optional): Other load configuration options for compatibility.
+            We do not recommend using these configurations; if they are not
+            necessary, DO NOT use them. Default None.
+            The following options are currently supported:
+            (1) model_filename (str): The filename to load the translated program
+            of the target Layer. The default filename is :code:`__model__` .
+            (2) params_filename (str): The filename to load all persistable
+            variables of the target Layer. The default filename is :code:`__variables__` .
+            (3) separate_params (bool): Whether to load the Layer parameters from
+            separate files. If True, each parameter is loaded from its own file,
+            named after the parameter, and the params_filename configuration does
+            not take effect. Default False.
+            (4) keep_name_table (bool): Whether to keep the ``structured_name ->
+            parameter_name`` dict in the loaded state dict. This dict is debugging
+            information saved when calling ``paddle.save`` . It is generally only
+            used for debugging and does not affect training or inference. By
+            default it is not retained in the ``paddle.load`` result. Default: False.
     Returns:
         Object(Object): a target object can be used in paddle
@@ -242,8 +276,7 @@ def load(path, config=None):
                     "`paddle.load('model')`."
         raise ValueError(error_msg % path)

-    if config is None:
-        config = paddle.SaveLoadConfig()
+    config = _parse_load_config(configs)

     # 2. load target
     load_result = None
...
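A sketch of the resulting paddle.load behavior, grounded in the _parse_load_config shown above (the 'mnist' directory is an assumed example):

import paddle

# the compatibility options are now plain keywords
state_dict = paddle.load(
    'mnist', model_filename='__model__', params_filename='__variables__')

# any other keyword is rejected during parsing
try:
    paddle.load('mnist', unknown_option=True)
except ValueError as e:
    print(e)  # The additional config (unknown_option) of `paddle.load` is not supported.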