# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
import pickle
import six
from . import learning_rate_scheduler
import warnings
from .. import core

# Public names exported by ``from <this module> import *``.
__all__ = ['save_dygraph', 'load_dygraph']


@dygraph_only
def save_dygraph(state_dict, model_path):
    '''
    :api_attr: imperative

    Save a state_dict (e.g. from ``Layer.state_dict``) to disk.

    If ``state_dict`` contains at least one parameter (``ParamBase``), it is
    written to ``model_path + ".pdparams"``. A state_dict without any
    parameter, such as an optimizer's, is written to
    ``model_path + ".pdopt"``.

    Variable values are converted to numpy arrays before pickling; all
    other values are stored as-is. A table mapping each state_dict key to
    its variable name is stored under the key
    "StructuredToParameterName@@".

    Args:
        state_dict(dict) : The state dict to be saved. Must not be empty.
        model_path(str) : The file prefix to save the state_dict, in the
            format "dirname/file_prefix". If file_prefix is an empty
            string, an exception will be raised.

    Returns:
        None

    Raises:
        AssertionError: if the file prefix of ``model_path`` is empty or
            ``state_dict`` is empty.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.noam_decay(100, 10000),
                    parameter_list=emb.parameters())

                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

    '''

    base_name = os.path.basename(model_path)
    assert base_name != "", "The input model_path MUST be format of dirname/filename [dirname\\filename in Windows system], but received filename is empty string."

    assert len(state_dict) > 0, "state_dict is empty, no need to save"

    # Classify by the whole dict, not just its first entry. If a parameter
    # dict happens to start with a non-ParamBase value (e.g. a buffer), a
    # first-entry check would mislabel it ".pdopt". Dict order is also not
    # guaranteed on Python 2.
    if any(isinstance(v, ParamBase) for v in state_dict.values()):
        suffix = ".pdparams"
    else:
        suffix = ".pdopt"

    model_dict = {}
    name_table = {}
    for k, v in state_dict.items():
        if isinstance(v, (Variable, core.VarBase)):
            model_dict[k] = v.numpy()
            # Only variables carry a ``.name``. Non-variable values handled
            # by the ``else`` branch (e.g. nested dicts) would raise
            # AttributeError here.
            name_table[k] = v.name
        else:
            model_dict[k] = v
    model_dict["StructuredToParameterName@@"] = name_table

    file_name = model_path + suffix
    dir_name = os.path.dirname(file_name)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

    with open(file_name, 'wb') as f:
        # Protocol 2 is the highest pickle protocol Python 2 can read.
        pickle.dump(model_dict, f, protocol=2)


def _load_pickled_dict(file_path):
    """Unpickle and return the object stored in ``file_path``.

    On Python 3, ``encoding='latin1'`` is passed so that files pickled under
    Python 2 (e.g. containing numpy arrays) can still be decoded.
    """
    # NOTE(review): unpickling can execute arbitrary code; only load
    # checkpoint files that come from a trusted source.
    with open(file_path, 'rb') as f:
        if six.PY2:
            return pickle.load(f)
        return pickle.load(f, encoding='latin1')


@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
    '''
    :api_attr: imperative

    Load parameter and optimizer state_dicts from disk.

    Args:
        model_path(str) : The file prefix of the saved state_dict. A trailing
            ".pdparams" or ".pdopt" suffix is accepted and stripped.
        keep_name_table(bool, optional) : Whether to keep the structured-name
            to parameter-name conversion table
            ("StructuredToParameterName@@") in the returned parameter dict.
            Default : False

    Returns:
        tuple: ``(para_dict, opti_dict)``.

        ``para_dict`` is the dict loaded from "<prefix>.pdparams".

        ``opti_dict`` is the dict loaded from "<prefix>.pdopt", or None if
        that file does not exist. It is returned exactly as stored;
        ``keep_name_table`` does not apply to it.

    Raises:
        RuntimeError: if "<prefix>.pdparams" does not exist.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(
                    learning_rate=fluid.layers.noam_decay(100, 10000),
                    parameter_list=emb.parameters())
                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy")

    '''

    # Accept the prefix with or without a known suffix.
    model_prefix = model_path
    for known_suffix in (".pdparams", ".pdopt"):
        if model_prefix.endswith(known_suffix):
            model_prefix = model_prefix[:-len(known_suffix)]
            break

    params_file_path = model_prefix + ".pdparams"
    if not os.path.exists(params_file_path):
        raise RuntimeError("Parameter file [ {} ] not exists".format(
            params_file_path))

    para_dict = _load_pickled_dict(params_file_path)
    if not keep_name_table:
        para_dict.pop("StructuredToParameterName@@", None)

    # The optimizer state is optional; it is returned only if saved.
    opti_dict = None
    opti_file_path = model_prefix + ".pdopt"
    if os.path.exists(opti_file_path):
        opti_dict = _load_pickled_dict(opti_file_path)

    return para_dict, opti_dict