# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import os
import collections
from ..framework import Variable, default_main_program
import pickle
from . import learning_rate_scheduler
import warnings

__all__ = ['save_persistables', 'load_persistables']


def save_persistables(model_dict,
                      optimizer=None,
                      dirname='save_dir',
                      filename=None):
    """
    This function filters out all variables in layer.parameters from the
    give `layer` and then trys to load these variables from the folder
    `dirname` or the file `filename`.

    Use the `dirname` to specify the folder where persistable variables were
    saved. If variables were saved in separate files, set `filename` None;
    if all variables were saved in a single file, use `filename` to specify
    the file name.

    Args:
42
        model_dict(dict of Parameters): The parameters will
43 44 45 46
                                    be saved. If it is None, nothing
                                    will be deal.
        dirname(str): The directory path.
        filename(str|None): The file which saved all variables. If variables were
47
                            saved in different files, set it to None.
48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
                            Default: None

    Returns:
        None

    Examples:
        .. code-block:: python

            ptb_model = PtbModel(
                hidden_size=hidden_size,
                vocab_size=vocab_size,
                num_layers=num_layers,
                num_steps=num_steps,
                init_scale=init_scale)

            x_data = np.arange(12).reshape(4, 3).astype('int64')
            y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
            x_data = x_data.reshape((-1, num_steps, 1))
            y_data = y_data.reshape((-1, 1))
            init_hidden_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
            init_cell_data = np.zeros(
                (num_layers, batch_size, hidden_size), dtype='float32')
            x = to_variable(x_data)
            y = to_variable(y_data)
            init_hidden = to_variable(init_hidden_data)
            init_cell = to_variable(init_cell_data)
            dy_loss, last_hidden, last_cell = ptb_model(x, y, init_hidden,
                                                        init_cell)
            param_path = "./my_paddle_model"
            fluid.dygraph.save_persistables(ptb_model.state_dict(),
                                            dirname=param_path)
    """
    if isinstance(model_dict, collections.OrderedDict):
        _save_var_to_file(model_dict, optimizer, dirname, filename)
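
# A minimal end-to-end sketch (hypothetical layer and data names, built under
# fluid.dygraph.guard(); shown only to illustrate the call order):
#
#     with fluid.dygraph.guard():
#         mlp = MLP("mlp")
#         sgd = fluid.optimizer.SGDOptimizer(learning_rate=1e-3)
#         loss = mlp(to_variable(x_data))
#         loss.backward()
#         sgd.minimize(loss)
#         fluid.dygraph.save_persistables(mlp.state_dict(), sgd, "./ckpt")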


def load_persistables(dirname='save_dir'):
    """
    This function trys to load persistable variables from the folder
    `dirname` or the file `filename`.

    Use the `dirname` to specify the folder where persistable variables were
    saved. If variables were saved in separate files, set `filename` None;
    if all variables were saved in a single file, use `filename` to specify
    the file name.

    Args:
94 95
        dirname(str): The directory path. default is save_dir
        optimizer(Optimizer): Optimizer to be saved
96 97 98 99 100 101

    Returns:
        dict: The parameter-dict resumed from file

    Examples:
        .. code-block:: python
102
            my_layer = layer(fluid.Layer)
103 104
            param_path = "./my_paddle_model"

L
lujun 已提交
105
            param_dict = fluid.dygraph.load_persistables(my_layer.parameters(), param_path)
106 107 108
            param_1 = param_dict['PtbModel_0.w_1']

        """
    return _load_var_from_file(dirname)
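
# A minimal restore sketch (continuing the hypothetical `mlp` example above;
# assuming Layer.load_dict applies a parameter dict back onto a dygraph layer):
#
#     param_dict, optimizer_dict = fluid.dygraph.load_persistables("./ckpt")
#     mlp.load_dict(param_dict)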


def _save_var_to_file(stat_dict, optimizers, file_dir, file_name):
    save_block = default_main_program().global_block()
    save_var_map = {}
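    # One `save` op is appended per variable so that each persistable is
    # written to its own file under `file_dir` (the file_name is None case).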
    for var_key, each_var in stat_dict.items():
        save_var_map[each_var.name] = each_var
        if file_name is None:
            save_block.append_op(
                type='save',
                inputs={'X': [each_var]},
                outputs={},
                attrs={
                    'file_path': os.path.join(file_dir,
                                              os.path.normpath(each_var.name))
                })
    if not isinstance(optimizers, (list, tuple)):
        optimizers = [optimizers]
    optimizers_dir = os.path.join(file_dir, os.path.normpath("optimizers"))
    if not os.path.exists(optimizers_dir):
        os.makedirs(optimizers_dir)
    for optimizer in optimizers:
        if optimizer is None:
            # No optimizer was passed in, so there is nothing to save.
            continue
        if isinstance(optimizer._learning_rate,
                      learning_rate_scheduler.LearningRateDecay):
            # Pickle the learning-rate scheduler state so that
            # load_persistables can restore it later.
            opt_file_path = os.path.join(
                optimizers_dir, os.path.normpath(str(optimizer._name)))
            try:
                with open(opt_file_path, "wb") as f:
                    pickle.dump(optimizer._learning_rate, f, 2)
            except IOError:
                raise IOError("Can't save %s" % opt_file_path)
        else:
            warnings.warn(
                "Optimizer not saved: under dygraph mode, only optimizers "
                "with a 'LearningRateDecay' learning rate need to be saved.")

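    # When file_name is given, a single `save_combine` op writes every
    # variable into one file, in sorted-name order for determinism.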
    if file_name is not None:
        save_var_list = []
        for name in sorted(save_var_map.keys()):
            save_var_list.append(save_var_map[name])

        save_block.append_op(
            type='save_combine',
            inputs={'X': save_var_list},
            outputs={},
            attrs={
                'file_path': os.path.join(file_dir, os.path.normpath(file_name))
            })


def _load_var_from_file(file_dir):
    def walk_filename(file_dir):
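        # Collect the relative paths of every variable file under
        # `file_dir`, skipping the `optimizers` sub-folder and hidden files.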
        base_path = os.path.join(file_dir)
        var_name_list = []
        if os.path.exists(base_path):
            for dirpath, dirnames, filenames in os.walk(base_path):
                if "optimizers" in dirpath:
                    continue
                pt = dirpath.replace(base_path, "", 1)
                if pt.startswith("/") or pt.startswith("\\"):
                    pt = pt[1:]
                for fth_name in filenames:
                    if fth_name[0] != '.':
                        name_path = os.path.join(pt, fth_name)
                        if "\\" in name_path:
                            name_path = name_path.replace("\\", "/")
                        var_name_list.append(name_path)

        return var_name_list

    load_block = default_main_program().global_block()
    load_var_map = {}
    load_optimizer_map = {}
    file_var_list = walk_filename(file_dir)
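    # One `load` op is appended per discovered file; each op reads a single
    # persistable variable back from disk into `load_block`.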
    for var_name in file_var_list:
        new_var = Variable(block=load_block, name=var_name)
        load_block.append_op(
            type='load',
            inputs={},
            outputs={'Out': [new_var]},
            attrs={
                'file_path': os.path.join(file_dir,
                                          os.path.normpath(new_var.name))
            })

        load_var_map[new_var.name] = new_var
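
    # Restore pickled learning-rate scheduler states from the `optimizers`
    # sub-folder; entries are keyed by the optimizer file name.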
    opt_path = os.path.join(file_dir, "optimizers")
    for _, _, optimizers in os.walk(opt_path):
        for optimizer in optimizers:
            try:
                with open(os.path.join(opt_path, optimizer), "rb") as f:
                    load_optimizer_map[optimizer] = pickle.load(f)
            except IOError:
                # `optimizer` is a file name here, so report the path that
                # actually failed to open.
                raise IOError(
                    "Can't load %s" % os.path.join(opt_path, optimizer))
    if len(load_optimizer_map) == 0:
        warnings.warn("No optimizer loaded")

    return load_var_map, load_optimizer_map


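# Clone `var`'s metadata into `block`, marking the clone persistable and
# resetting its LoD level to 0.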
def _clone_var_in_block_(block, var):
    assert isinstance(var, Variable)
    return block.create_var(
        name=var.name,
        shape=var.shape,
        dtype=var.dtype,
        type=var.type,
        lod_level=0,
        persistable=True)