# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core

# Public API of this module: the names exported by a wildcard import.
__all__ = [
    'save_dygraph',
    'load_dygraph',
]
30 31


H
hong 已提交
32 33 34 35 36 37 38
@dygraph_only
def save_dygraph(state_dict, model_path):
    '''
    Save a state_dict to disk with pickle.

    The output file is ``model_path`` plus a suffix. The suffix is chosen by
    inspecting only the first entry of ``state_dict``: ".pdparams" when that
    entry is a ParamBase, ".pdopt" otherwise.

    Entries that are Variable / VarBase are stored as the result of their
    ``numpy()`` call and their ``name`` is recorded in a name table saved
    under the key "StructuredToParameterName@@". Any other entry is pickled
    as-is and is not added to the name table.

    Args:
        state_dict(dict) : The state dict to be saved. Must not be empty.
        model_path(str) : The file prefix used to save the state_dict. The
            format is "dirname/file_prefix". The directory is created when
            it does not exist. An empty file_prefix raises an exception.

    Returns:
        None

    Raises:
        AssertionError: If the file name part of ``model_path`` is empty or
            ``state_dict`` is empty.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph( state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
                                             parameter_list = emb.parameters() )

                state_dict = adam.state_dict()
                fluid.save_dygraph( state_dict, "paddle_dy")

    '''

    base_name = os.path.basename(model_path)
    assert base_name != "", "model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"

    assert len(state_dict) > 0, "state_dict is empty, no need to save"

    # Only the first entry decides the suffix, hence the unconditional break.
    suffix = ".pdparams"
    for value in state_dict.values():
        if not isinstance(value, ParamBase):
            suffix = ".pdopt"
        break

    model_dict = {}
    name_table = {}
    for key, value in state_dict.items():
        if isinstance(value, (Variable, core.VarBase)):
            model_dict[key] = value.numpy()
            # Only variables carry a ``name``; plain Python values stored in
            # the state dict (nested dicts, scalars, ...) do not, so looking
            # it up unconditionally would raise AttributeError.
            name_table[key] = value.name
        else:
            model_dict[key] = value
    model_dict["StructuredToParameterName@@"] = name_table

    file_name = model_path + suffix
    dir_name = os.path.dirname(file_name)
    if dir_name and not os.path.exists(dir_name):
        try:
            os.makedirs(dir_name)
        except OSError:
            # Another process/thread may have created the directory between
            # the existence check and makedirs; only re-raise real failures.
            if not os.path.isdir(dir_name):
                raise

    with open(file_name, 'wb') as f:
        # Protocol 2 is the highest pickle protocol readable by Python 2.
        pickle.dump(model_dict, f, protocol=2)
H
hong 已提交
93 94 95


@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
    '''
    Load the parameter state_dict, and the optimizer state_dict when one
    exists, from disk.

    The parameters are read from ``model_path + ".pdparams"``. The optimizer
    state is read from ``model_path + ".pdopt"`` only if that file exists.

    Args:
        model_path(str) : The file prefix the state_dict was stored under.
            (The path should Not contain the suffix '.pdparams')
        keep_name_table(bool, optional) : Whether to keep the structured name
            to parameter name conversion table (stored under the key
            "StructuredToParameterName@@") in the returned parameter dict.
            Default : False

    Returns:
        tuple: ``(para_dict, opti_dict)``. ``para_dict`` is the dict loaded
        from the ".pdparams" file; ``opti_dict`` is the dict loaded from the
        ".pdopt" file, or None when that file does not exist.

    Raises:
        RuntimeError: If the ".pdparams" file does not exist.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph( state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
                                             parameter_list = emb.parameters() )
                state_dict = adam.state_dict()
                fluid.save_dygraph( state_dict, "paddle_dy")

                para_state_dict, opti_state_dict = fluid.load_dygraph( "paddle_dy")

    '''

    def _unpickle(path):
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # checkpoint files that come from a trusted source.
        with open(path, 'rb') as stream:
            if six.PY2:
                return pickle.load(stream)
            # 'latin1' lets Python 3 read byte strings pickled by Python 2.
            return pickle.load(stream, encoding='latin1')

    name_table_key = "StructuredToParameterName@@"

    params_file_path = model_path + ".pdparams"
    if not os.path.exists(params_file_path):
        raise RuntimeError(
            "Parameter file [ {} ] not exists".format(params_file_path))

    para_dict = _unpickle(params_file_path)
    if not keep_name_table and name_table_key in para_dict:
        del para_dict[name_table_key]

    # The optimizer file is optional; a missing file yields None.
    opti_file_path = model_path + ".pdopt"
    if os.path.exists(opti_file_path):
        opti_dict = _unpickle(opti_file_path)
    else:
        opti_dict = None

    return para_dict, opti_dict