# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
from ..framework import Variable, default_main_program
import pickle
from . import learning_rate_scheduler
import warnings

__all__ = ['save_persistables', 'load_persistables']


def save_persistables(model_dict, dirname='save_dir', optimizers=None):
    """
    This function filters out all variables in layer.parameters from the given `layer`,
    together with the optimizer's learning rate decay, and then tries to save these
    variables to the folder `dirname`.

    Use the `dirname` to specify the folder where the persistable variables will be
    saved.

    Args:
        model_dict(dict of Parameters): The parameters will
                                    be saved. If it is None, nothing
                                    will be saved.
        dirname(str): The directory path.
        optimizers(fluid.Optimizer|list(fluid.Optimizer)|None): The optimizers to be saved

    Returns:
        None

    Examples:

        .. code-block:: python

          ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale)
          sgd = fluid.optimizer.SGD(learning_rate=0.01)
          x_data = np.arange(12).reshape(4, 3).astype('int64')
          y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
          x_data = x_data.reshape((-1, num_steps, 1))
          y_data = y_data.reshape((-1, 1))
          init_hidden_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
          init_cell_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
          x = to_variable(x_data)
          y = to_variable(y_data)
          init_hidden = to_variable(init_hidden_data)
          init_cell = to_variable(init_cell_data)
          dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
                                                        init_cell)
          dy_loss.backward()
          sgd.minimize(dy_loss)
          ptb_model.clear_gradient()
          param_path = "./my_paddle_model"
          fluid.dygraph.save_persistables(ptb_model.state_dict(),
                                          dirname=param_path, optimizers=sgd)
    """
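    # Only an OrderedDict state dict (as returned by Layer.state_dict()) is saved;
    # other inputs are silently ignored.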
    if isinstance(model_dict, collections.OrderedDict):
        _save_var_to_file(model_dict, optimizers, dirname, None)


def load_persistables(dirname='save_dir'):
    """
    This function tries to load the persistable variables and the optimizer's learning
    rate decay from the folder `dirname`, and returns the restored values as two
    dictionaries.

    Use the `dirname` to specify the folder where the persistable variables were
    saved.

    Args:
        dirname(str): The directory path. Default is 'save_dir'.

    Returns:
        layer_dict: The dict of parameters restored from file.
        optimizer_dict: The dict of restored optimizer learning rate schedulers.

    Examples:

         .. code-block:: python

           my_layer = layer(fluid.Layer)
           param_path = "./my_paddle_model"
           sgd = SGDOptimizer(learning_rate=1e-3)
           param_dict, optimizer_dict = fluid.dygraph.load_persistables(param_path)
           param_1 = param_dict['PtbModel_0.w_1']
           sgd.load(optimizer_dict)

    """
    return _load_var_from_file(dirname)


def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
    save_block = default_main_program().global_block()
    save_var_map = {}
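    # When no combined file_name is given, append one 'save' op per variable so
    # each parameter is written to its own file under file_dir.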
    for var_key, each_var in stat_dict.items():
        save_var_map[each_var.name] = each_var
        if file_name is None:
            save_block.append_op(
                type='save',
                inputs={'X': [each_var]},
                outputs={},
                attrs={
                    'file_path': os.path.join(file_dir,
                                              os.path.normpath(each_var.name))
                })

    if optimizers is not None:
        # Normalize to a list so a single optimizer and a collection are handled alike.
        if not isinstance(optimizers, (list, tuple)):
            optimizers = [optimizers]
        opt_dir = os.path.join(file_dir, os.path.normpath("optimizers"))
        if not os.path.exists(opt_dir):
            os.mkdir(opt_dir)
        for optimizer in optimizers:
            if isinstance(optimizer._learning_rate,
                          learning_rate_scheduler.LearningRateDecay):
                opt_file_path = os.path.join(
                    opt_dir, os.path.normpath(str(optimizer._name)))
                try:
                    # Pickle the learning rate scheduler so its state can be
                    # restored by load_persistables.
                    with open(opt_file_path, "wb") as f:
                        pickle.dump(optimizer._learning_rate, f, 2)
                except IOError:
                    raise IOError("Can't save {}".format(opt_file_path))
            else:
                warnings.warn(
                    "Optimizer not saved. Only optimizers with a 'LearningRateDecay' "
                    "learning rate under dygraph mode need to be saved.")

    if file_name is not None:
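        # All variables are combined into a single file written by one
        # 'save_combine' op, ordered by variable name.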
        save_var_list = []
        for name in sorted(save_var_map.keys()):
            save_var_list.append(save_var_map[name])

        save_block.append_op(
            type='save_combine',
            inputs={'X': save_var_list},
            outputs={},
            attrs={
                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
            })


def _load_var_from_file(file_dir):
    if not os.path.exists(file_dir):
        raise IOError("{} not exist".format(file_dir))

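    # Collect the relative paths of all saved variable files under file_dir,
    # skipping the "optimizers" subfolder, which holds pickled LR schedulers.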
    def walk_filename(file_dir):
        base_path = os.path.join(file_dir)
        var_name_list = []
        if os.path.exists(base_path):
            for dirpath, dirnames, filenames in os.walk(base_path):
                if "optimizers" in dirpath:
                    continue
                pt = dirpath.replace(base_path, "", 1)
                if pt.startswith("/") or pt.startswith("\\"):
                    pt = pt[1:]
                for fth_name in filenames:
                    if fth_name[0] != '.':
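                        # Build the variable's relative name and normalize Windows
                        # separators to forward slashes.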
                        name_path = os.path.join(pt, fth_name)
                        if "\\" in name_path:
                            name_path = name_path.replace("\\", "/")
                        var_name_list.append(name_path)

        return var_name_list

    load_block = default_main_program().global_block()
    load_var_map = {}
    load_optimizer_map = {}
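    # Create a variable for each saved file and append a 'load' op so its value
    # is restored from disk.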
    file_var_list = walk_filename(file_dir)
    for var_name in file_var_list:
        new_var = Variable(block=load_block, name=var_name)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [new_var]},
            attrs={
                'file_path': os.path.join(file_dir,
                                          os.path.normpath(new_var.name))
            })

        load_var_map[new_var.name] = new_var
    # Restore the pickled learning rate schedulers saved alongside the variables.
    opt_path = os.path.join(file_dir, "optimizers")
    for _, _, optimizers in os.walk(opt_path):
        for optimizer in optimizers:
            opt_file_path = os.path.join(opt_path, optimizer)
            try:
                with open(opt_file_path, "rb") as f:
                    load_optimizer_map[optimizer] = pickle.load(f)
            except IOError:
                raise IOError("Can't load {}".format(opt_file_path))
    if len(load_optimizer_map) == 0:
        print(
            "No optimizer loaded. If you didn't save the optimizer, please ignore "
            "this; the program can still work with a new optimizer.")

    return load_var_map, load_optimizer_map


def _clone_var_in_block_(block, var):
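    # Create a persistable copy of `var` (same name, shape, dtype and type) in
    # the given block.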
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=0,
        persistable=True)