# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
from ..framework import Variable, default_main_program
import pickle
from . import learning_rate_scheduler
import warnings

__all__ = ['save_persistables', 'load_persistables']
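
# On-disk checkpoint layout written by save_persistables and read by
# load_persistables:
#   <dirname>/<variable name>        - one file per persistable variable
#   <dirname>/optimizers/<opt name>  - pickled LearningRateDecay per optimizer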


def save_persistables(model_dict, dirname='save_dir', optimizers=None):
    """
    This function saves all the persistable variables in `model_dict`,
    together with the learning rate decay of each optimizer in `optimizers`,
    to the folder `dirname`.

    Use `dirname` to specify the folder where the variables are saved. Each
    variable is written to its own file inside that folder, named after the
    variable; learning rate decay objects are pickled into the `optimizers`
    subfolder.

    Args:
        model_dict(dict of Parameters): The parameters to be saved.
                                    If it is None, nothing will be saved.
        dirname(str): The directory path. Default: 'save_dir'.
        optimizers(fluid.Optimizer|list(fluid.Optimizer)|None): The optimizers
                                    whose learning rate decay will be saved.

    Returns:
        None

    Examples:

        .. code-block:: python

          # a minimal sketch: assumes PtbModel is a fluid.dygraph.Layer and
          # hidden_size, vocab_size, num_layers, num_steps, init_scale and
          # batch_size are defined beforehand
          import numpy as np
          import paddle.fluid as fluid
          from paddle.fluid.dygraph.base import to_variable

          with fluid.dygraph.guard():
              ptb_model = PtbModel(
                  hidden_size=hidden_size,
                  vocab_size=vocab_size,
                  num_layers=num_layers,
                  num_steps=num_steps,
                  init_scale=init_scale)
              sgd = fluid.optimizer.SGD(learning_rate=0.01)
              x_data = np.arange(12).reshape(4, 3).astype('int64')
              y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
              x_data = x_data.reshape((-1, num_steps, 1))
              y_data = y_data.reshape((-1, 1))
              init_hidden_data = np.zeros(
                  (num_layers, batch_size, hidden_size), dtype='float32')
              init_cell_data = np.zeros(
                  (num_layers, batch_size, hidden_size), dtype='float32')
              x = to_variable(x_data)
              y = to_variable(y_data)
              init_hidden = to_variable(init_hidden_data)
              init_cell = to_variable(init_cell_data)
              dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
                                                          init_cell)
              dy_loss.backward()
              sgd.minimize(dy_loss)
              ptb_model.clear_gradients()
              param_path = "./my_paddle_model"
              fluid.dygraph.save_persistables(
                  ptb_model.state_dict(), dirname=param_path, optimizers=sgd)
    """
    if isinstance(model_dict, collections.OrderedDict):
        _save_var_to_file(model_dict, optimizers, dirname, None)


def load_persistables(dirname='save_dir'):
    """
    This function tries to load the persistable variables and the pickled
    learning rate decay objects saved by `save_persistables` from the folder
    `dirname`.

    Use `dirname` to specify the folder where the variables were saved. Each
    variable is expected to live in its own file named after the variable,
    and learning rate decay objects in the `optimizers` subfolder.

    Args:
        dirname(str): The directory path. Default: 'save_dir'.

    Returns:
        dict: The parameter dict restored from the files in `dirname`.
        dict: The optimizer dict restored from the `optimizers` subfolder.

    Examples:

        .. code-block:: python

          # a minimal sketch: assumes save_persistables already wrote a
          # checkpoint to "./my_paddle_model"
          import paddle.fluid as fluid

          with fluid.dygraph.guard():
              param_path = "./my_paddle_model"
              sgd = fluid.optimizer.SGDOptimizer(learning_rate=1e-3)
              param_dict, optimizer_dict = fluid.dygraph.load_persistables(param_path)
              param_1 = param_dict['PtbModel_0.w_1']
              sgd.load(optimizer_dict)
    """
    return _load_var_from_file(dirname)


def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
    save_block = default_main_program().global_block()
    save_var_map = {}
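    # when file_name is None, append one 'save' op per variable so each
    # variable is written to its own file named after the variable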
    for var_key, each_var in stat_dict.items():
        save_var_map[each_var.name] = each_var
        if file_name is None:
            save_block.append_op(
                type='save',
                inputs={'X': [each_var]},
                outputs={},
                attrs={
                    'file_path': os.path.join(file_dir,
                                              os.path.normpath(each_var.name))
                })

    if optimizers is not None:
        # accept either a single optimizer or a list/tuple of optimizers
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        opt_dir = os.path.join(file_dir, os.path.normpath("optimizers"))
        if not os.path.exists(opt_dir):
            os.mkdir(opt_dir)
        for optimizer in optimizers:
            if isinstance(optimizer._learning_rate,
                          learning_rate_scheduler.LearningRateDecay):
                opt_file_path = os.path.join(
                    opt_dir, os.path.normpath(str(optimizer._name)))
                try:
                    # pickle the LearningRateDecay object with protocol 2
                    with open(opt_file_path, "wb") as f:
                        pickle.dump(optimizer._learning_rate, f, 2)
                except IOError:
                    raise IOError("Can't save %s" % opt_file_path)
            else:
                warnings.warn(
                    "Optimizer not saved: under dygraph mode, only optimizers "
                    "with a 'LearningRateDecay' learning rate need to be saved."
                )

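    # when a single target file is given, merge all variables into it with
    # one 'save_combine' op, sorted by name for a deterministic order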
    if file_name is not None:
        save_var_list = []
        for name in sorted(save_var_map.keys()):
            save_var_list.append(save_var_map[name])

        save_block.append_op(
            type='save_combine',
            inputs={'X': save_var_list},
            outputs={},
            attrs={
                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
            })


def _load_var_from_file(file_dir):
    if not os.path.exists(file_dir):
        raise IOError("{} does not exist".format(file_dir))

    def walk_filename(file_dir):
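        # collect the relative paths of all saved variable files under
        # file_dir, skipping the 'optimizers' subfolder and hidden files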
        base_path = file_dir
        var_name_list = []
        if os.path.exists(base_path):
            for dirpath, dirnames, filenames in os.walk(base_path):
                if "optimizers" in dirpath:
                    continue
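                # strip base_path so the variable name keeps only the path
                # relative to the checkpoint folder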
                pt = dirpath.replace(base_path, "", 1)
                if pt.startswith("/") or pt.startswith("\\"):
                    pt = pt[1:]
                for fth_name in filenames:
                    if fth_name[0] != '.':
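                        # normalize to forward slashes so variable names are
                        # platform-independent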
                        name_path = os.path.join(pt, fth_name)
                        if "\\" in name_path:
                            name_path = name_path.replace("\\", "/")
                        var_name_list.append(name_path)

        return var_name_list

    load_block = default_main_program().global_block()
    load_var_map = {}
    load_optimizer_map = {}
    file_var_list = walk_filename(file_dir)
    for var_name in file_var_list:
        # create a placeholder variable and append a 'load' op that reads the
        # file named after the variable into it
        new_var = Variable(block=load_block, name=var_name)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [new_var]},
            attrs={
                'file_path': os.path.join(file_dir,
                                          os.path.normpath(new_var.name))
            })
        load_var_map[new_var.name] = new_var
    # unpickle every LearningRateDecay object saved under 'optimizers',
    # keyed by its file name (the optimizer name)
    opt_path = os.path.join(file_dir, "optimizers")
    for _, _, optimizer_files in os.walk(opt_path):
        for optimizer_file in optimizer_files:
            file_path = os.path.join(opt_path, optimizer_file)
            try:
                with open(file_path, "rb") as f:
                    load_optimizer_map[optimizer_file] = pickle.load(f)
            except IOError:
                raise IOError("Can't load %s" % file_path)
    if len(load_optimizer_map) == 0:
        print(
            "No optimizer loaded. If you didn't save an optimizer, please "
            "ignore this. The program can still work with a new optimizer.")

    return load_var_map, load_optimizer_map


def _clone_var_in_block_(block, var):
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=0,
        persistable=True)