checkpoint.py 4.6 KB
Newer Older
1
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
H
hong 已提交
19
from ..framework import Variable, default_main_program, in_dygraph_mode, dygraph_only, Parameter, ParamBase
20 21 22
import pickle
from . import learning_rate_scheduler
import warnings
H
hong 已提交
23
from .. import core
24

H
hong 已提交
25 26 27 28
# Public entry points exported by this module.
__all__ = ['save_dygraph', 'load_dygraph']
29 30


H
hong 已提交
31 32 33 34 35 36 37
@dygraph_only
def save_dygraph(state_dict, model_path):
    '''
    Save a Layer's or Optimizer's state_dict to disk.

    The state_dict is obtained from ``Layer.state_dict()`` or
    ``Optimizer.state_dict()``. The file written is ``model_path + suffix``:
    the suffix is ".pdparams" when state_dict contains at least one
    parameter (ParamBase), and ".pdopt" otherwise.

    Args:
        state_dict(dict) : The state dict to be saved. Must not be empty.
        model_path(str) : The file prefix to save the state_dict, in the format
            "dirname/file_prefix". If file_prefix is an empty str, an
            exception will be raised.

    Returns:
        None

    Raises:
        AssertionError: if the file prefix of model_path is empty, or if
            state_dict is empty.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay(100, 10000),
                                            parameter_list=emb.parameters())

                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

    '''

    base_name = os.path.basename(model_path)
    assert base_name != "", "model_path MUST be format of dirname/filename [dirname\\filename in Window], Now filename is empty str"

    assert len(state_dict) > 0, "state_dict is empty, no need to save"

    # Decide the suffix by inspecting *every* entry. The previous logic only
    # looked at the first entry, so a parameter dict whose first entry
    # happened not to be a ParamBase was mis-saved as ".pdopt".
    has_param = any(isinstance(v, ParamBase) for v in state_dict.values())
    suffix = ".pdparams" if has_param else ".pdopt"

    model_dict = {}
    # Maps each structured key to the framework variable name, so a loader
    # can translate between the two naming schemes.
    name_table = {}
    for k, v in state_dict.items():
        if isinstance(v, (Variable, core.VarBase)):
            model_dict[k] = v.numpy()
            name_table[k] = v.name
        else:
            # Non-tensor entries (e.g. nested dicts) are pickled unchanged.
            # They are not guaranteed to have a `.name`, so they get no
            # name-table entry.
            model_dict[k] = v
    model_dict["StructuredToParameterName@@"] = name_table

    file_name = model_path + suffix
    dir_name = os.path.dirname(file_name)
    if dir_name and not os.path.exists(dir_name):
        os.makedirs(dir_name)

    with open(file_name, 'wb') as f:
        pickle.dump(model_dict, f)
H
hong 已提交
92 93 94


@dygraph_only
def load_dygraph(model_path, keep_name_table=False):
    '''
    Load the parameter and optimizer state_dicts from disk.

    Reads ``model_path + ".pdparams"`` (required) and
    ``model_path + ".pdopt"`` (optional).

    Args:
        model_path(str) : The file prefix under which the state_dict was saved.
            It should not contain the suffix '.pdparams'. For convenience, if
            ``model_path + ".pdparams"`` does not exist and model_path itself
            ends with '.pdparams' or '.pdopt', that suffix is stripped and
            the remainder is used as the prefix.
        keep_name_table(bool, optional) : Whether to keep the structured-name
            to parameter-name conversion table ("StructuredToParameterName@@")
            in the returned parameter dict. Default : False

    Returns:
        tuple(dict, dict|None) : (para_dict, opti_dict). opti_dict is None
        when no ".pdopt" file exists for the prefix.

    Raises:
        RuntimeError: if the ".pdparams" file does not exist.

    Note:
        The files are read with ``pickle.load``, which can execute arbitrary
        code. Only load checkpoints from trusted sources.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            with fluid.dygraph.guard():
                emb = fluid.dygraph.Embedding([10, 10])

                state_dict = emb.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                adam = fluid.optimizer.Adam(learning_rate=fluid.layers.noam_decay(100, 10000),
                                            parameter_list=emb.parameters())
                state_dict = adam.state_dict()
                fluid.save_dygraph(state_dict, "paddle_dy")

                para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy")

    '''

    model_prefix = model_path
    params_file_path = model_prefix + ".pdparams"
    if not os.path.exists(params_file_path):
        # Callers often pass the full file name instead of the prefix. Only
        # fall back to stripping the suffix when the literal prefix has no
        # file, so every previously-working path resolves exactly as before.
        for known_suffix in (".pdparams", ".pdopt"):
            if model_path.endswith(known_suffix):
                model_prefix = model_path[:-len(known_suffix)]
                params_file_path = model_prefix + ".pdparams"
                break

    if not os.path.exists(params_file_path):
        raise RuntimeError("Parameter file [ {} ] not exists".format(
            params_file_path))

    # NOTE: pickle.load runs arbitrary code embedded in the file; checkpoint
    # files must come from a trusted source.
    with open(params_file_path, 'rb') as f:
        para_dict = pickle.load(f)

    if not keep_name_table and "StructuredToParameterName@@" in para_dict:
        del para_dict["StructuredToParameterName@@"]

    opti_dict = None
    opti_file_path = model_prefix + ".pdopt"
    if os.path.exists(opti_file_path):
        with open(opti_file_path, 'rb') as f:
            opti_dict = pickle.load(f)

    return para_dict, opti_dict