# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
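
"""Checkpoint utilities for PaddlePaddle dygraph mode: save persistable
variables (model parameters and, optionally, an optimizer's learning rate
decay) to disk, and load them back."""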

from __future__ import print_function

import os
import collections
import pickle
import warnings

from ..framework import Variable, default_main_program
from . import learning_rate_scheduler

__all__ = ['save_persistables', 'load_persistables']


def save_persistables(model_dict, dirname='save_dir', optimizers=None):
    """
    This function filters out all variables in layer.parameters from the
    given `layer`, together with the optimizer's learning rate decay, and
    then tries to save these variables to the folder `dirname`.

    Use `dirname` to specify the folder where the persistable variables
    are saved; each variable is stored in its own file under that folder.

    Args:
        model_dict(dict of Parameters): The parameters to be saved.
                                    If it is None, nothing will be saved.
        dirname(str): The directory path. Default: 'save_dir'.
        optimizers(fluid.Optimizer|list(fluid.Optimizer)|None): The optimizers to be saved

    Returns:
        None

    Examples:

        .. code-block:: python

          import numpy as np
          import paddle.fluid as fluid
          from paddle.fluid.dygraph import to_variable

          # PtbModel is a user-defined fluid.dygraph.Layer; its definition,
          # the hyperparameters below, and an enclosing
          # fluid.dygraph.guard() scope are assumed to exist already.
          ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale)
          sgd = fluid.optimizer.SGD(learning_rate=0.01)
          x_data = np.arange(12).reshape(4, 3).astype('int64')
          y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
          x_data = x_data.reshape((-1, num_steps, 1))
          y_data = y_data.reshape((-1, 1))
          init_hidden_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
          init_cell_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
          x = to_variable(x_data)
          y = to_variable(y_data)
          init_hidden = to_variable(init_hidden_data)
          init_cell = to_variable(init_cell_data)
          dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
                                                      init_cell)
          dy_loss.backward()
          sgd.minimize(dy_loss)
          ptb_model.clear_gradients()
          param_path = "./my_paddle_model"
          fluid.dygraph.save_persistables(ptb_model.state_dict(),
                                          dirname=param_path,
                                          optimizers=sgd)
    """
    if isinstance(model_dict, collections.OrderedDict):
        _save_var_to_file(model_dict, optimizers, dirname, None)


def load_persistables(dirname='save_dir'):
    """
    This function tries to load persistable variables from the folder
    `dirname`.

    Use `dirname` to specify the folder where the persistable variables
    were saved; each variable is expected to live in its own file under
    that folder.

    Args:
        dirname(str): The directory path. Default: 'save_dir'.

    Returns:
        dict: The parameter dict restored from file.
        dict: The optimizer dict restored from file.

    Examples:

        .. code-block:: python

          import paddle.fluid as fluid
          from paddle.fluid.optimizer import SGDOptimizer

          # Assumes an enclosing fluid.dygraph.guard() scope and a model
          # previously saved with save_persistables.
          param_path = "./my_paddle_model"
          sgd = SGDOptimizer(learning_rate=1e-3)
          param_dict, optimizer_dict = fluid.dygraph.load_persistables(param_path)
          param_1 = param_dict['PtbModel_0.w_1']
          sgd.load(optimizer_dict)

        """
    return _load_var_from_file(dirname)


def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
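    # Emit one `save` op per variable in `stat_dict` when `file_name` is
    # None (file-per-variable layout), or a single `save_combine` op
    # otherwise; learning rate decay objects of the given optimizers are
    # pickled into an "optimizers" subdirectory of `file_dir`.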
    save_block = default_main_program().global_block()
    save_var_map = {}
    for var_key, each_var in stat_dict.items():
        save_var_map[each_var.name] = each_var
        if file_name is None:
            save_block.append_op(
                type='save',
                inputs={'X': [each_var]},
                outputs={},
                attrs={
                    'file_path': os.path.join(file_dir,
                                              os.path.normpath(each_var.name))
                })

    if optimizers is not None:
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        opt_dir = os.path.join(file_dir, os.path.normpath("optimizers"))
        if not os.path.exists(opt_dir):
            os.mkdir(opt_dir)
        for optimizer in optimizers:
            if isinstance(optimizer._learning_rate,
                          learning_rate_scheduler.LearningRateDecay):
                opt_file_path = os.path.join(
                    opt_dir, os.path.normpath(str(optimizer._name)))
                try:
                    with open(opt_file_path, "wb") as f:
                        pickle.dump(optimizer._learning_rate, f, 2)
                except IOError:
                    raise IOError("Can't save %s" % opt_file_path)
            else:
                warnings.warn(
                    "Optimizer not saved: under dygraph mode, only "
                    "optimizers whose learning rate is a "
                    "'LearningRateDecay' need to be saved.")

    if file_name is not None:
        save_var_list = []
        for name in sorted(save_var_map.keys()):
            save_var_list.append(save_var_map[name])

        save_block.append_op(
            type='save_combine',
            inputs={'X': save_var_list},
            outputs={},
            attrs={
                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
            })


def _load_var_from_file(file_dir):
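    # Rebuild the variable map by walking `file_dir`: every regular file
    # outside the "optimizers" subdirectory is read back through a `load`
    # op, and pickled learning rate decay objects are restored from
    # "optimizers".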
    def walk_filename(file_dir):
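        # Collect variable file paths relative to `file_dir`, skipping
        # the "optimizers" subdirectory and hidden files.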
        base_path = os.path.join(file_dir)
        var_name_list = []
        if os.path.exists(base_path):
            for dirpath, dirnames, filenames in os.walk(base_path):
                if "optimizers" in dirpath:
                    continue
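                # Make the path relative to base_path so variable names
                # round-trip regardless of the save directory.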
                pt = dirpath.replace(base_path, "", 1)
                if pt.startswith("/") or pt.startswith("\\"):
                    pt = pt[1:]
                for fth_name in filenames:
                    if fth_name[0] != '.':
                        name_path = os.path.join(pt, fth_name)
                        if "\\" in name_path:
                            name_path = name_path.replace("\\", "/")
                        var_name_list.append(name_path)
        return var_name_list

    load_block = default_main_program().global_block()
    load_var_map = {}
    load_optimizer_map = {}
    file_var_list = walk_filename(file_dir)
    for var_name in file_var_list:
        new_var = Variable(block=load_block, name=var_name)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [new_var]},
            attrs={
                'file_path': os.path.join(file_dir,
                                          os.path.normpath(new_var.name))
            })

        load_var_map[new_var.name] = new_var
    opt_path = os.path.join(file_dir, "optimizers")
    for _, _, optimizers in os.walk(opt_path):
        for optimizer in optimizers:
            opt_file_path = os.path.join(opt_path, optimizer)
            try:
                with open(opt_file_path, "rb") as f:
                    load_optimizer_map[optimizer] = pickle.load(f)
            except IOError:
                raise IOError("Can't load %s" % opt_file_path)
    if len(load_optimizer_map) == 0:
        print(
            "No optimizer loaded. If you didn't save an optimizer, please "
            "ignore this. The program can still work with a new optimizer.")

    return load_var_map, load_optimizer_map


def _clone_var_in_block_(block, var):
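    # Create a persistable copy of `var`'s description in `block`; note
    # that lod_level is reset to 0.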
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=0,
        persistable=True)