# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
import pickle
import warnings

from ..framework import Variable, default_main_program
from . import learning_rate_scheduler

__all__ = ['save_persistables', 'load_persistables']
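
# On-disk layout used by save_persistables/load_persistables below:
#     <dirname>/<variable name>              -- one file per persistable variable
#     <dirname>/optimizers/<optimizer name>  -- pickled LearningRateDecay objects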


def save_persistables(model_dict, dirname='save_dir', optimizers=None):
    """
    This function saves all variables in the given `model_dict` (e.g. the
    result of a Layer's `state_dict()`), together with the optimizers'
    learning rate decay state, to the folder `dirname`.

    Use `dirname` to specify the folder where the variables are to be saved;
    each variable is written to its own file under that folder.

    Args:
        model_dict(dict of Parameters): The parameters to be saved. If it is
                                    None, nothing will be done.
        dirname(str): The directory path. Default: 'save_dir'.
        optimizers(fluid.Optimizer|list(fluid.Optimizer)|None): The optimizers to be saved.

    Returns:
        None

    Examples:
        .. code-block:: python

            import numpy as np
            import paddle.fluid as fluid
            from paddle.fluid.dygraph import to_variable

            # assumes PtbModel is a fluid.dygraph.Layer subclass and that the
            # hyper-parameters (hidden_size, batch_size, ...) are defined;
            # run inside a fluid.dygraph.guard() context
            ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale)
            sgd = fluid.optimizer.SGD(learning_rate=0.01)
            x_data = np.arange(12).reshape(4, 3).astype('int64')
            y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
            x_data = x_data.reshape((-1, num_steps, 1))
            y_data = y_data.reshape((-1, 1))
            init_hidden_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
            init_cell_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
            x = to_variable(x_data)
            y = to_variable(y_data)
            init_hidden = to_variable(init_hidden_data)
            init_cell = to_variable(init_cell_data)
            dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
                                                        init_cell)
            dy_loss.backward()
            sgd.minimize(dy_loss)
            ptb_model.clear_gradients()
            param_path = "./my_paddle_model"
            fluid.dygraph.save_persistables(
                ptb_model.state_dict(), dirname=param_path, optimizers=sgd)
    """
    if isinstance(model_dict, collections.OrderedDict):
        _save_var_to_file(model_dict, optimizers, dirname, None)


def load_persistables(dirname='save_dir'):
    """
    This function tries to load persistable variables and pickled optimizer
    learning rate schedulers from the folder `dirname`.

    Use `dirname` to specify the folder where the persistable variables were
    saved; each variable is expected to be in its own file under that folder.

    Args:
        dirname(str): The directory path. Default: 'save_dir'.

    Returns:
        dict: The parameter dict restored from files.
        dict: The optimizer dict (learning rate schedulers) restored from files.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            from paddle.fluid.optimizer import SGDOptimizer

            # assumes MyLayer is a user-defined fluid.dygraph.Layer subclass;
            # run inside a fluid.dygraph.guard() context
            my_layer = MyLayer("my_layer")
            param_path = "./my_paddle_model"
            sgd = SGDOptimizer(learning_rate=1e-3)
            param_dict, optimizer_dict = fluid.dygraph.load_persistables(
                param_path)
            param_1 = param_dict['PtbModel_0.w_1']
            sgd.load(optimizer_dict)

        """
    return _load_var_from_file(dirname)


def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
    save_block = default_main_program().global_block()
    save_var_map = {}
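    # Append one 'save' op per variable so that each persistable variable is
    # written to its own file, <file_dir>/<variable name>.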
    for var_key, each_var in stat_dict.items():
        save_var_map[each_var.name] = each_var
        if file_name is None:
            save_block.append_op(
                type='save',
                inputs={'X': [each_var]},
                outputs={},
                attrs={
                    'file_path': os.path.join(file_dir,
                                              os.path.normpath(each_var.name))
                })

    if optimizers is not None:
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        optimizers_dir = os.path.join(file_dir, os.path.normpath("optimizers"))
        if not os.path.exists(optimizers_dir):
            os.makedirs(optimizers_dir)
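        # Only LearningRateDecay schedulers carry state worth checkpointing;
        # they are pickled under <file_dir>/optimizers, one file per optimizer,
        # named after optimizer._name.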
        for optimizer in optimizers:
            if isinstance(optimizer._learning_rate,
                          learning_rate_scheduler.LearningRateDecay):
                opt_path = os.path.join(file_dir, "optimizers",
                                        os.path.normpath(str(optimizer._name)))
                try:
                    with open(opt_path, "wb") as f:
                        pickle.dump(optimizer._learning_rate, f, 2)
                except IOError:
                    raise IOError("Can't save %s" % opt_path)
            else:
                warnings.warn(
                    "Optimizer not saved: only optimizers with a 'LearningRateDecay' learning rate under dygraph mode need to be saved."
                )

    if file_name is not None:
        save_var_list = []
        for name in sorted(save_var_map.keys()):
            save_var_list.append(save_var_map[name])

        save_block.append_op(
            type='save_combine',
            inputs={'X': save_var_list},
            outputs={},
            attrs={
                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
            })


def _load_var_from_file(file_dir):
    def walk_filename(file_dir):
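        # Collect the paths (relative to file_dir) of all saved variable
        # files, skipping the pickled optimizer states and hidden files.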
        base_path = file_dir
        var_name_list = []
        if os.path.exists(base_path):
            for dirpath, dirnames, filenames in os.walk(base_path):
                if "optimizers" in dirpath:
                    continue
                pt = dirpath.replace(base_path, "", 1)
                if pt.startswith("/") or pt.startswith("\\"):
                    pt = pt[1:]
                for fth_name in filenames:
                    if fth_name[0] != '.':
                        name_path = os.path.join(pt, fth_name)
                        if "\\" in name_path:
                            name_path = name_path.replace("\\", "/")
                        var_name_list.append(name_path)

        return var_name_list

    load_block = default_main_program().global_block()
    load_var_map = {}
    load_optimizer_map = {}
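    # For every variable file found under file_dir, create a placeholder
    # Variable and append a 'load' op that reads the saved tensor back into it.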
    file_var_list = walk_filename(file_dir)
    for var_name in file_var_list:
        new_var = Variable(block=load_block, name=var_name)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [new_var]},
            attrs={
                'file_path': os.path.join(file_dir,
                                          os.path.normpath(new_var.name))
            })

        load_var_map[new_var.name] = new_var
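
    # Restore the pickled LearningRateDecay objects, keyed by file name
    # (which is the saving optimizer's _name).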
    opt_path = os.path.join(file_dir, "optimizers")
    for _, _, optimizer_files in os.walk(opt_path):
        for optimizer_file in optimizer_files:
            try:
                with open(os.path.join(opt_path, optimizer_file), "rb") as f:
                    load_optimizer_map[optimizer_file] = pickle.load(f)
            except IOError:
                raise IOError("Can't load %s" %
                              os.path.join(opt_path, optimizer_file))
    if len(load_optimizer_map) == 0:
        print(
            "No optimizer loaded. If you didn't save an optimizer, please ignore this. The program can still work with a new optimizer."
        )

    return load_var_map, load_optimizer_map


def _clone_var_in_block_(block, var):
    assert isinstance(var, Variable)
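    # Recreate the variable's metadata in the target block and mark the clone
    # persistable; lod_level is reset to 0.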
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=0,
        persistable=True)