From bb84f0e64612ad4b6899c61aab2d3e97b1177b27 Mon Sep 17 00:00:00 2001 From: Chen Weihang Date: Wed, 23 Sep 2020 20:12:46 +0800 Subject: [PATCH] Add new paddle.save/load APIs (#27331) * init commit of new save/load * fix failed unittests * fix save_load_v2 unittest failed * fix failed unittest & polish doc * add tests for coverage * add more tests & move static apis * fix example code error * polish emample code * fix detail example code problem --- python/paddle/fluid/dygraph/checkpoint.py | 18 +- python/paddle/fluid/dygraph/jit.py | 27 +- python/paddle/fluid/dygraph/layers.py | 6 +- python/paddle/fluid/dygraph/parallel.py | 6 +- python/paddle/fluid/optimizer.py | 20 +- .../unittests/test_imperative_save_load.py | 18 +- .../unittests/test_imperative_save_load_v2.py | 29 +- .../test_load_state_dict_from_old_format.py | 41 ++- .../tests/unittests/test_paddle_save_load.py | 148 +++++++++ python/paddle/framework/__init__.py | 4 +- python/paddle/framework/io.py | 291 ++++++++++++++++++ python/paddle/io/__init__.py | 8 - python/paddle/static/__init__.py | 13 +- python/paddle/tensor/__init__.py | 2 - python/paddle/tensor/io.py | 19 -- 15 files changed, 539 insertions(+), 111 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_paddle_save_load.py create mode 100644 python/paddle/framework/io.py delete mode 100644 python/paddle/tensor/io.py diff --git a/python/paddle/fluid/dygraph/checkpoint.py b/python/paddle/fluid/dygraph/checkpoint.py index 93cb0bafc8..f4ea4d670e 100644 --- a/python/paddle/fluid/dygraph/checkpoint.py +++ b/python/paddle/fluid/dygraph/checkpoint.py @@ -145,7 +145,7 @@ def load_dygraph(model_path, config=None): .. note:: Due to some historical reasons, if you load ``state_dict`` from the saved - result of `paddle.io.save_inference_model`, the structured variable name + result of `paddle.static.save_inference_model`, the structured variable name will cannot be restored. You need to set the argument `use_structured_name=False` when using `Layer.set_state_dict` later. @@ -164,24 +164,24 @@ def load_dygraph(model_path, config=None): .. 
code-block:: python import paddle - + import paddle.fluid as fluid + paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) state_dict = emb.state_dict() - paddle.save(state_dict, "paddle_dy") + fluid.save_dygraph(state_dict, "paddle_dy") - scheduler = paddle.optimizer.lr_scheduler.NoamLR( + scheduler = paddle.optimizer.lr_scheduler.NoamLR( d_model=0.01, warmup_steps=100, verbose=True) adam = paddle.optimizer.Adam( learning_rate=scheduler, parameters=emb.parameters()) state_dict = adam.state_dict() - paddle.save(state_dict, "paddle_dy") - - para_state_dict, opti_state_dict = paddle.load("paddle_dy") + fluid.save_dygraph(state_dict, "paddle_dy") + para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy") ''' # deal with argument `model_path` model_prefix = model_path @@ -275,7 +275,7 @@ def load_dygraph(model_path, config=None): # If users save all parameters as one file, the [ variable.name -> variable ] # mapping info will lost, so users need to give variable list, but users build # variable list in dygraph mode is difficult, we recommend users to use - # paddle.io.load_program_state in this case + # paddle.static.load_program_state in this case # Try to load all the files in the directory in VarBase format, # the file name is used as the name of VarBase diff --git a/python/paddle/fluid/dygraph/jit.py b/python/paddle/fluid/dygraph/jit.py index 10819e4b32..d0e3d23b04 100644 --- a/python/paddle/fluid/dygraph/jit.py +++ b/python/paddle/fluid/dygraph/jit.py @@ -231,9 +231,7 @@ def declarative(function=None, input_spec=None): class SaveLoadConfig(object): """ The additional configuration options may be used in function - :ref:`api_imperative_jit_save` that save :ref:`api_imperative_TranslatedLayer` - or used in function :ref:`api_imperative_jit_load` that - load :ref:`api_imperative_TranslatedLayer` . + ``paddle.jit.save/load`` and ``paddle.load`` . Examples: 1. Using ``SaveLoadConfig`` when saving model @@ -319,7 +317,7 @@ class SaveLoadConfig(object): @property def output_spec(self): """ - Selects the output targets of the saved model ( :ref:`api_imperative_TranslatedLayer` ). + Selects the output targets of the saved model ( ``paddle.jit.TranslatedLayer`` ). By default, all return variables of original Layer's forward function are kept as the output of the saved TranslatedLayer. @@ -531,11 +529,14 @@ class SaveLoadConfig(object): def separate_params(self): """ Configure whether to save the Layer parameters as separete files. - (In order to be compatible with the behavior of :ref:`api_fluid_io_save_inference_model` ) + (In order to be compatible with the behavior of ``paddle.static.save_inference_model`` ) If True, each parameter will be saved to a file separately, the file name is the parameter name, and the SaveLoadConfig.params_filename configuration will not take effect. Default False. + .. note:: + Only used for ``paddle.jit.save`` . + Examples: .. code-block:: python @@ -569,7 +570,7 @@ class SaveLoadConfig(object): adam.clear_grad() model_path = "simplenet.example.model.separate_params" - config = paddle.jit.SaveLoadConfig() + config = paddle.SaveLoadConfig() config.separate_params = True # saving with configs.separate_params @@ -599,12 +600,12 @@ class SaveLoadConfig(object): def keep_name_table(self): """ Configures whether keep ``structured_name -> parameter_name`` dict in loaded state dict. - This dict is the debugging information saved when call `paddle.save`. 
+ This dict is the debugging information saved when call ``paddle.save`` . It is generally only used for debugging and does not affect the actual training or inference. - By default, it will not be retained in `paddle.load` result. Default: False. + By default, it will not be retained in ``paddle.load`` result. Default: False. .. note:: - Only used for ``paddle.load``. + Only used for ``paddle.load`` . Examples: .. code-block:: python @@ -616,11 +617,11 @@ class SaveLoadConfig(object): linear = paddle.nn.Linear(5, 1) state_dict = linear.state_dict() - paddle.save(state_dict, "paddle_dy") + paddle.save(state_dict, "paddle_dy.pdparams") - configs = paddle.SaveLoadConfig() - configs.keep_name_table = True - para_state_dict, _ = paddle.load("paddle_dy", configs) + config = paddle.SaveLoadConfig() + config.keep_name_table = True + para_state_dict = paddle.load("paddle_dy.pdparams", config) print(para_state_dict) # the name_table is 'StructuredToParameterName@@' diff --git a/python/paddle/fluid/dygraph/layers.py b/python/paddle/fluid/dygraph/layers.py index 9c79deaab7..88e24e7e1e 100644 --- a/python/paddle/fluid/dygraph/layers.py +++ b/python/paddle/fluid/dygraph/layers.py @@ -970,12 +970,12 @@ class Layer(core.Layer): paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) state_dict = emb.state_dict() - paddle.save(state_dict, "paddle_dy") + paddle.save(state_dict, "paddle_dy.pdparams") - para_state_dict, _ = paddle.load("paddle_dy") + para_state_dict = paddle.load("paddle_dy.pdparams") emb.set_state_dict(para_state_dict) diff --git a/python/paddle/fluid/dygraph/parallel.py b/python/paddle/fluid/dygraph/parallel.py index 472022bced..de761cad52 100644 --- a/python/paddle/fluid/dygraph/parallel.py +++ b/python/paddle/fluid/dygraph/parallel.py @@ -610,13 +610,13 @@ class DataParallel(layers.Layer): paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) emb = fluid.dygraph.DataParallel(emb, strategy) state_dict = emb.state_dict() - paddle.save(state_dict, "paddle_dy") + paddle.save(state_dict, "paddle_dy.pdparams") - para_state_dict, _ = paddle.load("paddle_dy") + para_state_dict = paddle.load("paddle_dy.pdparams") emb.set_state_dict(para_state_dict) diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 1e7915ed78..0dd1694c86 100755 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -182,23 +182,25 @@ class Optimizer(object): Examples: .. 
code-block:: python - import paddle + import paddle + import paddle.fluid as fluid paddle.disable_static() - emb = paddle.nn.Embedding([10, 10]) + emb = paddle.nn.Embedding(10, 10) state_dict = emb.state_dict() - paddle.save(state_dict, "paddle_dy") + fluid.save_dygraph(state_dict, "paddle_dy") - adam = paddle.optimizer.Adam(learning_rate=fluid.layers.noam_decay( 100, 10000), - parameter_list=emb.parameters()) + scheduler = paddle.optimizer.lr_scheduler.NoamLR( + d_model=0.01, warmup_steps=100, verbose=True) + adam = paddle.optimizer.Adam( + learning_rate=scheduler, + parameters=emb.parameters()) state_dict = adam.state_dict() + fluid.save_dygraph(state_dict, "paddle_dy") - para_state_dict, opti_state_dict = paddle.load("paddle_dy") - - adam.set_state_dict(opti_state_dict) - + para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy") ''' from paddle.optimizer.lr_scheduler import _LRScheduler if isinstance(self._learning_rate, _LRScheduler): diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py index 22e19efcb5..bee53fd10f 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load.py @@ -292,7 +292,7 @@ class TestDygraphPtbRnn(unittest.TestCase): np_t = v.numpy() self.model_base[k] = np_t - paddle.save(self.state_dict, "./test_dy") + fluid.save_dygraph(self.state_dict, "./test_dy") def testLoadAndSetVarBase(self): seed = 90 @@ -373,7 +373,7 @@ class TestDygraphPtbRnn(unittest.TestCase): if isinstance(adam._learning_rate, LearningRateDecay): adam._learning_rate.step_num = 0 - para_state_dict, opti_state_dict = paddle.load("./test_dy") + para_state_dict, opti_state_dict = fluid.load_dygraph("./test_dy") adam.set_state_dict(opti_state_dict) opti_dict = adam.state_dict() @@ -898,31 +898,31 @@ class TestDygraphPtbRnn(unittest.TestCase): with fluid.dygraph.guard(): emb = fluid.dygraph.Embedding([10, 10]) state_dict = emb.state_dict() - paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + fluid.save_dygraph(state_dict, os.path.join('saved_dy', 'emb_dy')) - para_state_dict, opti_state_dict = paddle.load( + para_state_dict, opti_state_dict = fluid.load_dygraph( os.path.join('saved_dy', 'emb_dy')) self.assertTrue(opti_state_dict == None) - para_state_dict, opti_state_dict = paddle.load( + para_state_dict, opti_state_dict = fluid.load_dygraph( os.path.join('saved_dy', 'emb_dy.pdparams')) - para_state_dict, opti_state_dict = paddle.load( + para_state_dict, opti_state_dict = fluid.load_dygraph( os.path.join('saved_dy', 'emb_dy.pdopt')) def test_load_compatible_with_keep_name_table(self): with fluid.dygraph.guard(): emb = fluid.dygraph.Embedding([10, 10]) state_dict = emb.state_dict() - paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + fluid.save_dygraph(state_dict, os.path.join('saved_dy', 'emb_dy')) - para_state_dict, opti_state_dict = paddle.load( + para_state_dict, opti_state_dict = fluid.load_dygraph( os.path.join('saved_dy', 'emb_dy'), True) self.assertTrue(para_state_dict != None) self.assertTrue(opti_state_dict == None) - para_state_dict, opti_state_dict = paddle.load( + para_state_dict, opti_state_dict = fluid.load_dygraph( os.path.join('saved_dy', 'emb_dy'), keep_name_table=True) self.assertTrue(para_state_dict != None) self.assertTrue(opti_state_dict == None) diff --git a/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py 
b/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py index 3eb413a626..5b7998198e 100644 --- a/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py +++ b/python/paddle/fluid/tests/unittests/test_imperative_save_load_v2.py @@ -285,7 +285,7 @@ class TestDygraphPtbRnn(unittest.TestCase): else: self.base_opti[k] = v - fluid.save_dygraph(self.opti_dict, "./test_dy_v2") + paddle.save(self.opti_dict, "./test_dy_v2.pdopt") self.state_dict = ptb_model.state_dict() @@ -294,7 +294,7 @@ class TestDygraphPtbRnn(unittest.TestCase): np_t = v.numpy() self.model_base[k] = np_t - paddle.save(self.state_dict, "./test_dy_v2") + paddle.save(self.state_dict, "./test_dy_v2.pdparams") def testLoadAndSetVarBase(self): self.setUp() @@ -374,7 +374,8 @@ class TestDygraphPtbRnn(unittest.TestCase): self.assertTrue(np.sum(np.abs(v.numpy())) == 0) - para_state_dict, opti_state_dict = paddle.load("./test_dy_v2") + para_state_dict = paddle.load("./test_dy_v2.pdparams") + opti_state_dict = paddle.load("./test_dy_v2.pdopt") adam.set_state_dict(opti_state_dict) opti_dict = adam.state_dict() @@ -905,26 +906,19 @@ class TestDygraphPtbRnn(unittest.TestCase): with fluid.dygraph.guard(): emb = fluid.dygraph.Embedding([10, 10]) state_dict = emb.state_dict() - paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy.pdparams')) - para_state_dict, opti_state_dict = paddle.load( - os.path.join('saved_dy', 'emb_dy')) - - self.assertTrue(opti_state_dict == None) - - para_state_dict, opti_state_dict = paddle.load( + para_state_dict = paddle.load( os.path.join('saved_dy', 'emb_dy.pdparams')) - para_state_dict, opti_state_dict = paddle.load( - os.path.join('saved_dy', 'emb_dy.pdopt')) - def test_no_state_in_input_dict(self): with fluid.dygraph.guard(): emb = fluid.dygraph.Embedding([10, 10]) state_dict = emb.state_dict() - paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy.pdparams')) - para_state_dict, _ = paddle.load(os.path.join('saved_dy', 'emb_dy')) + para_state_dict = paddle.load( + os.path.join('saved_dy', 'emb_dy.pdparams')) para_state_dict.pop('weight') emb.set_state_dict(para_state_dict) @@ -933,9 +927,10 @@ class TestDygraphPtbRnn(unittest.TestCase): with fluid.dygraph.guard(): emb = fluid.dygraph.Embedding([10, 10]) state_dict = emb.state_dict() - paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy')) + paddle.save(state_dict, os.path.join('saved_dy', 'emb_dy.pdparams')) - para_state_dict, _ = paddle.load(os.path.join('saved_dy', 'emb_dy')) + para_state_dict = paddle.load( + os.path.join('saved_dy', 'emb_dy.pdparams')) para_state_dict['weight'] = np.expand_dims( para_state_dict['weight'], axis=-1) diff --git a/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py b/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py index a1a9b3f444..fdc1e6b52a 100644 --- a/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py +++ b/python/paddle/fluid/tests/unittests/test_load_state_dict_from_old_format.py @@ -124,52 +124,67 @@ class TestLoadStateDictFromSaveInferenceModel(unittest.TestCase): self.params_filename = None orig_param_dict = self.train_and_save_model() - load_param_dict, _ = paddle.load(self.save_dirname) + load_param_dict, _ = fluid.load_dygraph(self.save_dirname) self.check_load_state_dict(orig_param_dict, load_param_dict) + new_load_param_dict = paddle.load(self.save_dirname) + 
self.check_load_state_dict(orig_param_dict, new_load_param_dict) + def test_load_with_model_filename(self): self.save_dirname = "static_mnist.load_state_dict.model_filename" self.model_filename = "static_mnist.model" self.params_filename = None orig_param_dict = self.train_and_save_model() - configs = paddle.SaveLoadConfig() - configs.separate_params = True - configs.model_filename = self.model_filename - load_param_dict, _ = paddle.load(self.save_dirname, configs) + config = paddle.SaveLoadConfig() + config.separate_params = True + config.model_filename = self.model_filename + load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config) self.check_load_state_dict(orig_param_dict, load_param_dict) + new_load_param_dict = paddle.load(self.save_dirname, config) + self.check_load_state_dict(orig_param_dict, new_load_param_dict) + def test_load_with_param_filename(self): self.save_dirname = "static_mnist.load_state_dict.param_filename" self.model_filename = None self.params_filename = "static_mnist.params" orig_param_dict = self.train_and_save_model() - configs = paddle.SaveLoadConfig() - configs.params_filename = self.params_filename - load_param_dict, _ = paddle.load(self.save_dirname, configs) + config = paddle.SaveLoadConfig() + config.params_filename = self.params_filename + load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config) self.check_load_state_dict(orig_param_dict, load_param_dict) + new_load_param_dict = paddle.load(self.save_dirname, config) + self.check_load_state_dict(orig_param_dict, new_load_param_dict) + def test_load_with_model_and_param_filename(self): self.save_dirname = "static_mnist.load_state_dict.model_and_param_filename" self.model_filename = "static_mnist.model" self.params_filename = "static_mnist.params" orig_param_dict = self.train_and_save_model() - configs = paddle.SaveLoadConfig() - configs.params_filename = self.params_filename - configs.model_filename = self.model_filename - load_param_dict, _ = paddle.load(self.save_dirname, configs) + config = paddle.SaveLoadConfig() + config.params_filename = self.params_filename + config.model_filename = self.model_filename + load_param_dict, _ = fluid.load_dygraph(self.save_dirname, config) self.check_load_state_dict(orig_param_dict, load_param_dict) + new_load_param_dict = paddle.load(self.save_dirname, config) + self.check_load_state_dict(orig_param_dict, new_load_param_dict) + def test_load_state_dict_from_save_params(self): self.save_dirname = "static_mnist.load_state_dict.save_params" self.params_filename = None orig_param_dict = self.train_and_save_model(True) - load_param_dict, _ = paddle.load(self.save_dirname) + load_param_dict, _ = fluid.load_dygraph(self.save_dirname) self.check_load_state_dict(orig_param_dict, load_param_dict) + new_load_param_dict = paddle.load(self.save_dirname) + self.check_load_state_dict(orig_param_dict, new_load_param_dict) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_paddle_save_load.py b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py new file mode 100644 index 0000000000..74d44d0f8b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_paddle_save_load.py @@ -0,0 +1,148 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +import paddle +import paddle.nn as nn +import paddle.optimizer as opt + +BATCH_SIZE = 16 +BATCH_NUM = 4 +EPOCH_NUM = 4 +SEED = 10 + +IMAGE_SIZE = 784 +CLASS_NUM = 10 + + +# define a random dataset +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples): + self.num_samples = num_samples + + def __getitem__(self, idx): + np.random.seed(SEED) + image = np.random.random([IMAGE_SIZE]).astype('float32') + label = np.random.randint(0, CLASS_NUM - 1, (1, )).astype('int64') + return image, label + + def __len__(self): + return self.num_samples + + +class LinearNet(nn.Layer): + def __init__(self): + super(LinearNet, self).__init__() + self._linear = nn.Linear(IMAGE_SIZE, CLASS_NUM) + + def forward(self, x): + return self._linear(x) + + +def train(layer, loader, loss_fn, opt): + for epoch_id in range(EPOCH_NUM): + for batch_id, (image, label) in enumerate(loader()): + out = layer(image) + loss = loss_fn(out, label) + loss.backward() + opt.step() + opt.clear_grad() + + +class TestSaveLoad(unittest.TestCase): + def setUp(self): + # enable dygraph mode + self.place = paddle.CPUPlace() + paddle.disable_static(self.place) + + # config seed + paddle.manual_seed(SEED) + paddle.framework.random._manual_program_seed(SEED) + + def build_and_train_model(self): + # create network + layer = LinearNet() + loss_fn = nn.CrossEntropyLoss() + + adam = opt.Adam(learning_rate=0.001, parameters=layer.parameters()) + + # create data loader + dataset = RandomDataset(BATCH_NUM * BATCH_SIZE) + loader = paddle.io.DataLoader( + dataset, + places=self.place, + batch_size=BATCH_SIZE, + shuffle=True, + drop_last=True, + num_workers=2) + + # train + train(layer, loader, loss_fn, adam) + + return layer, adam + + def check_load_state_dict(self, orig_dict, load_dict): + for var_name, value in orig_dict.items(): + self.assertTrue(np.array_equal(value.numpy(), load_dict[var_name])) + + def test_save_load(self): + layer, opt = self.build_and_train_model() + + # save + layer_save_path = "linear.pdparams" + opt_save_path = "linear.pdopt" + layer_state_dict = layer.state_dict() + opt_state_dict = opt.state_dict() + + paddle.save(layer_state_dict, layer_save_path) + paddle.save(opt_state_dict, opt_save_path) + + # load + load_layer_state_dict = paddle.load(layer_save_path) + load_opt_state_dict = paddle.load(opt_save_path) + + self.check_load_state_dict(layer_state_dict, load_layer_state_dict) + self.check_load_state_dict(opt_state_dict, load_opt_state_dict) + + # test save load in static mode + paddle.enable_static() + static_save_path = "static_mode_test/linear.pdparams" + paddle.save(layer_state_dict, static_save_path) + load_static_state_dict = paddle.load(static_save_path) + self.check_load_state_dict(layer_state_dict, load_static_state_dict) + + # error test cases, some tests relay base test above + # 1. test save obj not dict error + test_list = [1, 2, 3] + with self.assertRaises(NotImplementedError): + paddle.save(test_list, "not_dict_error_path") + + # 2. 
test save path format error + with self.assertRaises(ValueError): + paddle.save(layer_state_dict, "linear.model/") + + # 3. test load path not exist error + with self.assertRaises(ValueError): + paddle.load("linear.params") + + # 4. test load old save path error + with self.assertRaises(ValueError): + paddle.load("linear") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/framework/__init__.py b/python/paddle/framework/__init__.py index f33e4e0fca..2ce442add2 100644 --- a/python/paddle/framework/__init__.py +++ b/python/paddle/framework/__init__.py @@ -48,8 +48,8 @@ from paddle.fluid import core #DEFINE_ALIAS from ..fluid.dygraph.base import no_grad #DEFINE_ALIAS from ..fluid.dygraph.base import to_variable #DEFINE_ALIAS from ..fluid.dygraph.base import grad #DEFINE_ALIAS -from ..fluid.dygraph.checkpoint import load_dygraph as load #DEFINE_ALIAS -from ..fluid.dygraph.checkpoint import save_dygraph as save #DEFINE_ALIAS +from .io import save +from .io import load from ..fluid.dygraph.jit import SaveLoadConfig #DEFINE_ALIAS from ..fluid.dygraph.parallel import DataParallel #DEFINE_ALIAS diff --git a/python/paddle/framework/io.py b/python/paddle/framework/io.py new file mode 100644 index 0000000000..7175f31014 --- /dev/null +++ b/python/paddle/framework/io.py @@ -0,0 +1,291 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import os +import collections +import pickle +import six +import warnings + +import paddle + +# deprecated module import +from paddle import fluid +from paddle.fluid import core +from paddle.fluid.framework import Variable, _varbase_creator, _dygraph_tracer +from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers, EXTRA_VAR_INFO_FILENAME + +__all__ = [ + 'save', + 'load', +] + + +def _build_saved_state_dict(state_dict): + save_dict = {} + name_table = {} + for key, value in state_dict.items(): + if isinstance(value, (Variable, core.VarBase)): + save_dict[key] = value.numpy() + name_table[key] = value.name + else: + save_dict[key] = value + save_dict["StructuredToParameterName@@"] = name_table + + return save_dict + + +def _load_state_dict_from_save_inference_model(model_path, config): + # 1. load program desc & construct _ProgramHolder + programs = _construct_program_holders(model_path, config.model_filename) + + # 2. load layer parameters & buffers + with fluid.dygraph.guard(): + persistable_var_dict = _construct_params_and_buffers( + model_path, + programs, + config.separate_params, + config.params_filename, + append_suffix=False) + + # 3. 
construct state_dict + load_param_dict = dict() + for var_name in persistable_var_dict: + load_param_dict[var_name] = persistable_var_dict[var_name].numpy() + + # if __variables.info__ exists, we can recover structured_name + var_info_path = os.path.join(model_path, EXTRA_VAR_INFO_FILENAME) + if os.path.exists(var_info_path): + with open(var_info_path, 'rb') as f: + extra_var_info = pickle.load(f) + structured_para_dict = dict() + for var_name in load_param_dict: + structured_name = extra_var_info[var_name].get( + 'structured_name', None) + assert structured_name is not None, "Cannot find saved variable (%s)'s structured name in saved model." % var_name + structured_para_dict[structured_name] = load_param_dict[ + var_name] + load_param_dict = structured_para_dict + + return load_param_dict + + +def _load_state_dict_from_save_params(model_path): + # Try to load all the files in the directory in VarBase format, + # the file name is used as the name of VarBase + load_var_list = [] + + # 1. load file names + var_name_list = [] + for root, _, files in os.walk(model_path): + for filename in files: + file_path = os.path.join(root, filename) + tmp_var_name = os.path.relpath(file_path, model_path) + var_name = tmp_var_name.replace("\\", "/") + var_name_list.append(var_name) + + # 2. create and load VarBase + with fluid.dygraph.guard(): + for name in var_name_list: + new_var = _varbase_creator(name=name, persistable=True) + _dygraph_tracer().trace_op( + type='load', + inputs={}, + outputs={'Out': new_var}, + attrs={'file_path': os.path.join(model_path, name)}) + load_var_list.append(new_var) + + # 3. construct state_dict + load_param_dict = dict() + for var in load_var_list: + load_param_dict[var.name] = var.numpy() + + return load_param_dict + + +def save(obj, path): + ''' + Save an object to the specified path. + + .. note:: + Now only supports save ``state_dict`` of Layer or Optimizer. + + Args: + obj(Object) : The object to be saved. + path(str) : The path of the object to be saved. + If saved in the current directory, the input path string will be used as the file name. + + Returns: + None + + Examples: + .. code-block:: python + + import paddle + + paddle.disable_static() + + emb = paddle.nn.Embedding(10, 10) + layer_state_dict = emb.state_dict() + paddle.save(layer_state_dict, "emb.pdparams") + + scheduler = paddle.optimizer.lr_scheduler.NoamLR( + d_model=0.01, warmup_steps=100, verbose=True) + adam = paddle.optimizer.Adam( + learning_rate=scheduler, + parameters=emb.parameters()) + opt_state_dict = adam.state_dict() + paddle.save(opt_state_dict, "adam.pdopt") + ''' + + # 1. input check + if not isinstance(obj, dict): + raise NotImplementedError( + "Now only supports save state_dict of Layer or Optimizer, " + "expect dict, but received %s." % type(obj)) + + if len(obj) == 0: + warnings.warn("The input state dict is empty, no need to save.") + + filename = os.path.basename(path) + if filename == "": + raise ValueError("The input path MUST be format of dirname/filename " + "[dirname\\filename in Windows system], but received " + "filename is empty string.") + + # 2. save object + dirname = os.path.dirname(path) + if dirname and not os.path.exists(dirname): + os.makedirs(dirname) + + # TODO(chenweihang): supports save other object + saved_obj = _build_saved_state_dict(obj) + + with open(path, 'wb') as f: + pickle.dump(saved_obj, f, protocol=2) + + +def load(path, config=None): + ''' + Load an object can be used in paddle from specified path. + + .. 
note:: + Now only supports load ``state_dict`` of Layer or Optimizer. + + .. note:: + ``paddle.load`` supports loading ``state_dict`` from the result of several + paddle1.x save APIs in static mode, but due to some historical reasons, + if you load ``state_dict`` from the saved result of + ``paddle.static.save_inference_model/paddle.fluid.io.save_params/paddle.fluid.io.save_persistables`` , + the structured variable name will cannot be restored. You need to set the argument + ``use_structured_name=False`` when using ``Layer.set_state_dict`` later. + + Args: + path(str) : The path to load the target object. Generally, the path is the target + file path, when compatible with loading the saved results of + ``paddle.jit.save/paddle.static.save_inference_model`` , the path is a directory. + config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig` + object that specifies additional configuration options, these options + are for compatibility with ``paddle.jit.save/paddle.static.save_inference_model`` + formats. Default None. + + Returns: + Object(Object): a target object can be used in paddle + + Examples: + .. code-block:: python + + import paddle + + paddle.disable_static() + + emb = paddle.nn.Embedding(10, 10) + layer_state_dict = emb.state_dict() + paddle.save(layer_state_dict, "emb.pdparams") + + scheduler = paddle.optimizer.lr_scheduler.NoamLR( + d_model=0.01, warmup_steps=100, verbose=True) + adam = paddle.optimizer.Adam( + learning_rate=scheduler, + parameters=emb.parameters()) + opt_state_dict = adam.state_dict() + paddle.save(opt_state_dict, "adam.pdopt") + + load_layer_state_dict = paddle.load("emb.pdparams") + load_opt_state_dict = paddle.load("adam.pdopt") + ''' + # 1. input check + if not os.path.exists(path): + error_msg = "The path `%s` does not exist." + # if current path is a prefix, and the path.pdparams or path.pdopt + # is exist, users may want use `paddle.load` load the result of + # `fluid.save_dygraph`, we raise error here for users + params_file_path = path + ".pdparams" + opti_file_path = path + ".pdopt" + if os.path.exists(params_file_path) or os.path.exists(opti_file_path): + error_msg += " If you want to load the results saved by `fluid.save_dygraph`, " \ + "please specify the full file name, not just the file name prefix. For " \ + "example, it should be written as `paddle.load('model.pdparams')` instead of " \ + "`paddle.load('model')`." + raise ValueError(error_msg % path) + + if config is None: + config = paddle.SaveLoadConfig() + + # 2. 
load target + load_result = None + if os.path.isfile(path): + # we think path is file means this file is created by paddle.save + with open(path, 'rb') as f: + load_result = pickle.load(f) if six.PY2 else pickle.load( + f, encoding='latin1') + + if not config.keep_name_table and "StructuredToParameterName@@" in load_result: + del load_result["StructuredToParameterName@@"] + elif os.path.isdir(path): + # we think path is directory means compatible with loading + # store results of static mode related save APIs + + # check whether model file exists + if config.model_filename is None: + model_filename = '__model__' + else: + model_filename = config.model_filename + model_file_path = os.path.join(path, model_filename) + + if os.path.exists(model_file_path): + # Load state dict by `jit.save/io.save_inference_model` save format + # NOTE(chenweihang): [ Compatibility of save_inference_model save format ] + # The model saved by `save_inference_model` does not completely correspond to + # the information required by the `state_dict` under the dygraph. + # `save_inference_model` not save structured name, we need to remind + # the user to configure the `use_structured_name` argument when `set_state_dict` + # NOTE(chenweihang): `jit.save` doesn't save optimizer state + load_result = _load_state_dict_from_save_inference_model(path, + config) + else: + # load state dict by `io.save_params/persistables` save format + # TODO(chenweihang): [ Now only supports loading parameters seperately ] + # If users save all parameters as one file, the [ variable.name -> variable ] + # mapping info will lost, so users need to give variable list, but users build + # variable list in dygraph mode is difficult, we recommend users to use + # paddle.static.load_program_state in this case + load_result = _load_state_dict_from_save_params(path) + else: + raise ValueError( + "Unsupported path format, now only supports file or directory.") + + return load_result diff --git a/python/paddle/io/__init__.py b/python/paddle/io/__init__.py index 6f0b0f3c9c..92dd819b3c 100644 --- a/python/paddle/io/__init__.py +++ b/python/paddle/io/__init__.py @@ -25,16 +25,8 @@ __all__ = [ 'Sampler', 'SequenceSampler', 'RandomSampler', - 'load', - 'save', - 'load_program_state', - 'set_program_state', - 'load_inference_model', - 'save_inference_model', ] from ..fluid.io import DataLoader from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \ TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler -from ..fluid.io import load, save, load_program_state, set_program_state, \ - load_inference_model, save_inference_model, batch diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index 42a28a4f04..e0a9bc6eec 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -17,8 +17,9 @@ __all__ = [ 'append_backward', 'gradients', 'Executor', 'global_scope', 'scope_guard', 'BuildStrategy', 'CompiledProgram', 'Print', 'py_func', 'ExecutionStrategy', 'name_scope', 'ParallelExecutor', 'program_guard', 'WeightNormParamAttr', - 'default_main_program', 'default_startup_program', 'Program', 'save', - 'load', 'data', 'InputSpec' + 'default_main_program', 'default_startup_program', 'Program', 'data', + 'InputSpec', 'save', 'load', 'save_inference_model', 'load_inference_model', + 'load_program_state', 'set_program_state' ] from . 
import nn @@ -41,5 +42,9 @@ from ..fluid.layers.control_flow import Print #DEFINE_ALIAS from ..fluid.layers.nn import py_func #DEFINE_ALIAS from ..fluid.parallel_executor import ParallelExecutor #DEFINE_ALIAS from ..fluid.param_attr import WeightNormParamAttr #DEFINE_ALIAS -from ..tensor.io import save #DEFINE_ALIAS -from ..tensor.io import load #DEFINE_ALIAS +from ..fluid.io import save #DEFINE_ALIAS +from ..fluid.io import load #DEFINE_ALIAS +from ..fluid.io import save_inference_model #DEFINE_ALIAS +from ..fluid.io import load_inference_model #DEFINE_ALIAS +from ..fluid.io import load_program_state #DEFINE_ALIAS +from ..fluid.io import set_program_state #DEFINE_ALIAS diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index cec989fba8..b6bab16c96 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -42,8 +42,6 @@ from .creation import tril #DEFINE_ALIAS from .creation import meshgrid #DEFINE_ALIAS from .creation import empty #DEFINE_ALIAS from .creation import empty_like #DEFINE_ALIAS -from .io import save #DEFINE_ALIAS -from .io import load #DEFINE_ALIAS from .linalg import matmul #DEFINE_ALIAS from .linalg import dot #DEFINE_ALIAS # from .linalg import einsum #DEFINE_ALIAS diff --git a/python/paddle/tensor/io.py b/python/paddle/tensor/io.py deleted file mode 100644 index 66e956e8e4..0000000000 --- a/python/paddle/tensor/io.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# TODO: define functions to save & load a tensor -from ..fluid import save #DEFINE_ALIAS -from ..fluid.io import load #DEFINE_ALIAS - -__all__ = ['save', 'load'] -- GitLab
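For reference, a minimal end-to-end round trip with the new ``paddle.save``/``paddle.load`` APIs, assembled from the docstring examples added in python/paddle/framework/io.py above; the file names are illustrative, and this is a sketch rather than part of the patch itself.

import paddle

paddle.disable_static()

# build a small layer and an optimizer, as in the new docstring examples
emb = paddle.nn.Embedding(10, 10)
scheduler = paddle.optimizer.lr_scheduler.NoamLR(
    d_model=0.01, warmup_steps=100, verbose=True)
adam = paddle.optimizer.Adam(
    learning_rate=scheduler, parameters=emb.parameters())

# one file per state_dict; the .pdparams/.pdopt suffixes follow the new unit test
paddle.save(emb.state_dict(), "emb.pdparams")
paddle.save(adam.state_dict(), "adam.pdopt")

# paddle.load returns a single state_dict per file (no (param, opt) tuple
# as with the old fluid.load_dygraph)
emb.set_state_dict(paddle.load("emb.pdparams"))
adam.set_state_dict(paddle.load("adam.pdopt"))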
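The ``keep_name_table`` option of ``SaveLoadConfig`` now applies to ``paddle.load``; a short sketch based on the updated jit.py docstring (the saved file name is illustrative):

import paddle

paddle.disable_static()

linear = paddle.nn.Linear(5, 1)
paddle.save(linear.state_dict(), "paddle_dy.pdparams")

config = paddle.SaveLoadConfig()
config.keep_name_table = True
state_dict = paddle.load("paddle_dy.pdparams", config)

# the debugging entry 'StructuredToParameterName@@' is kept in the result;
# without keep_name_table, paddle.load strips it from the loaded dict
print(state_dict)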
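Loading a ``state_dict`` from results written by the old static-mode APIs goes through the directory branch of ``paddle.load``; the following sketch assumes a directory previously produced by ``paddle.static.save_inference_model`` with a custom model file name (paths mirror the updated unit test and are not created by this example):

import paddle

paddle.disable_static()

# assumed to already exist on disk
model_dir = "static_mnist.load_state_dict.model_filename"

config = paddle.SaveLoadConfig()
config.model_filename = "static_mnist.model"  # non-default model file name

# the path is a directory, so paddle.load falls back to the
# save_inference_model / save_params compatibility loaders
state_dict = paddle.load(model_dir, config)

# the old formats do not record structured parameter names, so restore with
# layer.set_state_dict(state_dict, use_structured_name=False)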
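The new test_paddle_save_load.py above also pins down the error contract of the APIs; a condensed sketch of those cases:

import paddle

paddle.disable_static()

state_dict = paddle.nn.Linear(5, 1).state_dict()

# only dict-like state_dicts are supported for now
try:
    paddle.save([1, 2, 3], "not_dict_error_path")
except NotImplementedError:
    pass

# the save path must contain a file name, not end in a directory separator
try:
    paddle.save(state_dict, "linear.model/")
except ValueError:
    pass

# loading a non-existent path (e.g. a bare prefix written by fluid.save_dygraph)
# raises ValueError; the message suggests passing the full file name,
# such as paddle.load("linear.pdparams")
try:
    paddle.load("linear")
except ValueError:
    pass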