# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
import functools
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase, _varbase_creator, _dygraph_tracer
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core
from .base import guard
from paddle.fluid.dygraph.jit import SaveLoadConfig, deprecate_save_load_configs
from paddle.fluid.dygraph.io import _construct_program_holders, _construct_params_and_buffers

__all__ = [
    'save_dygraph',
    'load_dygraph',
]


# NOTE(chenweihang): deprecate load_dygraph's argument `keep_name_table`,
# ensuring compatibility when users still pass the `keep_name_table` argument.
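# A minimal sketch of the calls this decorator keeps working (for a decorated
# function such as `load_dygraph` below):
#     load_dygraph("paddle_dy", True)                  # old positional form
#     load_dygraph("paddle_dy", keep_name_table=True)  # old keyword form
# Both are rewritten into a `SaveLoadConfig` with `keep_name_table` set, which is
# then passed on via the `config` keyword argument.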
def deprecate_keep_name_table(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        def __warn_and_build_configs__(keep_name_table):
            warnings.warn(
                "The argument `keep_name_table` has deprecated, please use `SaveLoadConfig.keep_name_table`.",
                DeprecationWarning)
            config = SaveLoadConfig()
            config.keep_name_table = keep_name_table
            return config

        # deal with arg `keep_name_table`
        if len(args) > 1 and isinstance(args[1], bool):
            args = list(args)
            args[1] = __warn_and_build_configs__(args[1])
        # deal with kwargs
        elif 'keep_name_table' in kwargs:
            kwargs['config'] = __warn_and_build_configs__(kwargs[
                'keep_name_table'])
            kwargs.pop('keep_name_table')
        else:
            # do nothing
            pass

        return func(*args, **kwargs)

    return wrapper


@dygraph_only
def save_dygraph(state_dict, model_path):
    '''
    :api_attr: imperative

    Save a Layer's or Optimizer's state_dict to disk. A Layer's state_dict is saved with the suffix ".pdparams", an Optimizer's with ".pdopt".

    The state_dict is obtained from the Layer's or Optimizer's state_dict() function.
    
    Args:
        state_dict(dict) : The state dict to be saved.
        model_path(str) : The file prefix used to save the state_dict. The format is "dirname/file_prefix". If file_prefix is an empty string, an exception will be raised.

    Returns:
        None

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay(100, 10000),
                                            parameter_list=emb.parameters())

                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

    '''

    base_name = os.path.basename(model_path)
    assert base_name != "", "The input model_path MUST be in the format of dirname/filename [dirname\\filename on Windows], but the received filename is an empty string."

    suffix = ".pdparams"
    assert len(state_dict) > 0, "state_dict is empty, no need to save"

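    # NOTE: if the state_dict contains no ParamBase entries, it is assumed to be
    # an optimizer state_dict and is saved with the ".pdopt" suffix instead.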
    param_num = 0
    for k, v in state_dict.items():
        if isinstance(v, ParamBase):
            param_num += 1

    if param_num == 0:
        suffix = ".pdopt"

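    # Convert each Variable/VarBase in the state_dict to a numpy array so that it
    # can be pickled, and record the mapping from the structured (attribute) name
    # to the underlying parameter name so it can be restored by set_state_dict.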
    model_dict = {}
    name_table = {}
    for k, v in state_dict.items():
        if isinstance(v, (Variable, core.VarBase)):
            model_dict[k] = v.numpy()
            name_table[k] = v.name
        else:
            model_dict[k] = v
    model_dict["StructuredToParameterName@@"] = name_table

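    # Create the target directory if needed, then serialize the dict with pickle
    # protocol 2 so the file can be read by both Python 2 and Python 3.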
    file_name = model_path + suffix
    dir_name = os.path.dirname(file_name)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

    with open(file_name, 'wb') as f:
        pickle.dump(model_dict, f, protocol=2)


# TODO(qingqing01): remove dygraph_only to support loading static model.
# maybe need to unify the loading interface after 2.0 API is ready.
# @dygraph_only
@deprecate_save_load_configs
@deprecate_keep_name_table
def load_dygraph(model_path, config=None):
    '''
    :api_attr: imperative
    
    Load parameter state dict from disk.

    .. note::
        Due to some historical reasons, if you load ``state_dict`` from the saved 
        result of `paddle.io.save_inference_model`, the structured variable names
        cannot be restored. You need to set the argument `use_structured_name=False`
        when using `Layer.set_state_dict` later.

    Args:
        model_path(str) : The file prefix that stores the state_dict.
            (The path should not contain the suffix '.pdparams'.)
        config (SaveLoadConfig, optional): :ref:`api_imperative_jit_saveLoadConfig`
            object that specifies additional configuration options; these options
            are for compatibility with ``jit.save/io.save_inference_model`` formats.
            Default None.

    Returns:
        state_dict(dict) : two dicts, the parameter state_dict and the optimizer state_dict (the optimizer dict is None if no optimizer state was saved)

    Examples:
        .. code-block:: python

            import paddle

            paddle.disable_static()

            emb = paddle.nn.Embedding([10, 10])

            state_dict = emb.state_dict()
            paddle.save(state_dict, "paddle_dy")

            scheduler = paddle.optimizer.lr_scheduler.NoamLR(
                d_model=0.01, warmup_steps=100, verbose=True)
            adam = paddle.optimizer.Adam(
                learning_rate=scheduler,
                parameters=emb.parameters())
            state_dict = adam.state_dict()
            paddle.save(state_dict, "paddle_dy")

            para_state_dict, opti_state_dict = paddle.load("paddle_dy")

    '''
    # deal with argument `model_path`
    model_prefix = model_path
    if model_prefix.endswith(".pdparams"):
        model_prefix = model_prefix[:-9]
    elif model_prefix.endswith(".pdopt"):
        model_prefix = model_prefix[:-6]

    para_dict = None
    opti_dict = None
    params_file_path = model_prefix + ".pdparams"
    opti_file_path = model_prefix + ".pdopt"

    # deal with argument `configs`
    configs = config
    if configs is None:
        configs = SaveLoadConfig()

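    # Two saved formats are handled below: if neither the ".pdparams" nor the
    # ".pdopt" pickle file exists, fall back to loading a model saved by
    # `jit.save`/`io.save_inference_model`; otherwise load the `save_dygraph` files.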
    if not os.path.exists(params_file_path) and not os.path.exists(
            opti_file_path):
        # Load state dict by `jit.save/io.save_inference_model` save format
        # NOTE(chenweihang): [ Compatibility of save_inference_model save format ]
        # The model saved by `save_inference_model` does not completely correspond to
        # the information required by the `state_dict` in dygraph mode.
        # `save_inference_model` does not save structured names, so we need to remind
        # the user to set the `use_structured_name` argument when calling `set_state_dict`.
        # NOTE(chenweihang): `jit.save` doesn't save optimizer state 
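        # A minimal sketch of how the result of this branch is expected to be used
        # later (assuming a hypothetical layer instance `net` with a matching structure):
        #     para_dict, _ = fluid.load_dygraph(model_path)
        #     net.set_state_dict(para_dict, use_structured_name=False)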

        # 1. check model path
        if not os.path.isdir(model_prefix):
            raise ValueError("Model saved directory '%s' does not exist." %
                             model_prefix)

        # 2. load program desc & construct _ProgramHolder
        programs = _construct_program_holders(model_path,
                                              configs.model_filename)

        # 3. load layer parameters & buffers
        # NOTE: using fluid.dygraph.guard() here will cause import error in py2
        with guard():
            persistable_var_dict = _construct_params_and_buffers(
                model_prefix,
                programs,
                configs.separate_params,
                configs.params_filename,
                append_suffix=False)

            # 4. construct state_dict
            para_dict = dict()
            for var_name in persistable_var_dict:
                para_dict[var_name] = persistable_var_dict[var_name].numpy()
    else:
        # Load state dict by `save_dygraph` save format
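        # Parameters and optimizer states live in two separate pickle files
        # (".pdparams" and ".pdopt"); load whichever of them exists. Under Python 3,
        # `encoding='latin1'` is used so that files pickled by Python 2 can be read.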
        para_dict = {}
        if os.path.exists(params_file_path):
            with open(params_file_path, 'rb') as f:
                para_dict = pickle.load(f) if six.PY2 else pickle.load(
                    f, encoding='latin1')

        if not configs.keep_name_table and "StructuredToParameterName@@" in para_dict:
            del para_dict["StructuredToParameterName@@"]

        if os.path.exists(opti_file_path):
            with open(opti_file_path, 'rb') as f:
                opti_dict = pickle.load(f) if six.PY2 else pickle.load(
                    f, encoding='latin1')

    return para_dict, opti_dict