save_load.py 5.4 KB
Newer Older
W
WuHaobo 已提交
1 2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
W
WuHaobo 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
W
WuHaobo 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14 15 16 17 18

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WuHaobo 已提交
19
import errno
W
WuHaobo 已提交
20
import os
W
WuHaobo 已提交
21
import re
W
WuHaobo 已提交
22
import shutil
W
WuHaobo 已提交
23
import tempfile
W
WuHaobo 已提交
24 25 26 27 28

import paddle.fluid as fluid

from ppcls.utils import logger

29
# Names exported by ``from <this module> import *``. Helpers defined in this
# file but not listed here (e.g. ``load_distillation_model``) are not exported
# by a star-import.
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
W
WuHaobo 已提交
30 31 32 33


def _mkdir_if_not_exist(path):
    """
    Create directory ``path`` (including missing parents) if it does not exist.

    If another process creates the same directory between the existence check
    and the ``os.makedirs`` call, the resulting ``EEXIST`` error is tolerated
    and only reported as a warning.

    Args:
        path: directory to create.

    Raises:
        OSError: if the directory cannot be created. The errno and reason of
            the underlying failure are carried over to the raised error.
    """
    if os.path.exists(path):
        return

    try:
        os.makedirs(path)
    except OSError as e:
        if e.errno == errno.EEXIST and os.path.isdir(path):
            # Lost the creation race to another process; the directory is
            # there, which is all the caller needs.
            logger.warning(
                'be happy if some process has already created {}'.format(
                    path))
        else:
            # Keep errno/strerror so callers can still tell *why* it failed.
            raise OSError(e.errno, 'Failed to mkdir {}: {}'.format(
                path, e.strerror))
W
WuHaobo 已提交
46 47


littletomatodonkey's avatar
littletomatodonkey 已提交
48
def load_dygraph_pretrain(model, path=None, load_static_weights=False):
    """
    Load pretrained weights into ``model``.

    Args:
        model: object exposing ``state_dict()`` and ``set_dict()``.
        path: either an existing directory, or a prefix such that
            ``path + '.pdparams'`` exists.
        load_static_weights: when true, the weights are read with
            ``fluid.load_program_state`` and matched to the model's
            parameters through each parameter's ``name`` attribute;
            parameters without a match keep their current value. Otherwise
            the weights are read with ``fluid.load_dygraph``.

    Raises:
        ValueError: if ``path`` is None or does not point to pretrained
            weights as described above.
    """
    # ``path`` defaults to None; reject it here so the caller gets the
    # intended ValueError rather than a TypeError from the path helpers.
    if path is None or not (os.path.isdir(path) or
                            os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))

    if load_static_weights:
        pre_state_dict = fluid.load_program_state(path)
        model_dict = model.state_dict()
        param_state_dict = {}
        for key, param in model_dict.items():
            weight_name = param.name
            if weight_name in pre_state_dict:
                print('Load weight: {}, shape: {}'.format(
                    weight_name, pre_state_dict[weight_name].shape))
                param_state_dict[key] = pre_state_dict[weight_name]
            else:
                # No pretrained value for this parameter: keep the current one.
                param_state_dict[key] = param
        model.set_dict(param_state_dict)
        return

    # Only the parameter dict is needed; the optimizer state is ignored.
    param_state_dict, _ = fluid.load_dygraph(path)
    model.set_dict(param_state_dict)
    return
W
WuHaobo 已提交
70 71


littletomatodonkey's avatar
littletomatodonkey 已提交
72 73
def load_distillation_model(model, pretrained_model, load_static_weights):
    """
    Load pretrained weights for a teacher/student pair.

    The teacher (``model.teacher``) is loaded first from
    ``pretrained_model[0]``, then the student (``model.student``) from
    ``pretrained_model[1]``.

    Args:
        model: object exposing ``teacher`` and ``student`` sub-models.
        pretrained_model: sequence of exactly two paths, (teacher, student).
        load_static_weights: sequence of exactly two flags, (teacher,
            student), forwarded to ``load_dygraph_pretrain``.

    Raises:
        AssertionError: if either sequence does not have length 2.
    """
    logger.info("In distillation mode, teacher model will be "
                "loaded firstly before student model.")
    assert len(pretrained_model
               ) == 2, "pretrained_model length should be 2 but got {}".format(
                   len(pretrained_model))
    assert len(
        load_static_weights
    ) == 2, "load_static_weights length should be 2 but got {}".format(
        len(load_static_weights))

    teacher_path, student_path = pretrained_model[0], pretrained_model[1]

    load_dygraph_pretrain(
        model.teacher,
        path=teacher_path,
        load_static_weights=load_static_weights[0])
    # Report the path actually used for the teacher, not the whole list.
    logger.info(
        logger.coloring("Finish initing teacher model from {}".format(
            teacher_path), "HEADER"))

    load_dygraph_pretrain(
        model.student,
        path=student_path,
        load_static_weights=load_static_weights[1])
    # Report the path actually used for the student, not the whole list.
    logger.info(
        logger.coloring("Finish initing student model from {}".format(
            student_path), "HEADER"))

littletomatodonkey's avatar
littletomatodonkey 已提交
97

98
def init_model(config, net, optimizer=None):
    """
    load model from checkpoint or pretrained_model

    ``config['checkpoints']`` takes precedence: when set, parameters (and,
    if an optimizer is given, optimizer state) are restored from that prefix
    and the function returns. Otherwise, when ``config['pretrained_model']``
    is set, pretrained weights are loaded; a list value is loaded through
    ``load_distillation_model``. When neither is set, nothing is loaded.

    Args:
        config: mapping read via ``get`` for the keys ``checkpoints``,
            ``pretrained_model`` and ``load_static_weights``.
        net: model to initialise.
        optimizer: optional optimizer whose state is restored from the
            checkpoint. When None, optimizer state is neither required nor
            restored.

    Raises:
        AssertionError: if a required checkpoint file is missing.
    """
    checkpoints = config.get('checkpoints')
    if checkpoints:
        assert os.path.exists(checkpoints + ".pdparams"), \
            "Given dir {}.pdparams not exist.".format(checkpoints)
        # The optimizer file is only needed when there is an optimizer to
        # restore into.
        if optimizer is not None:
            assert os.path.exists(checkpoints + ".pdopt"), \
                "Given dir {}.pdopt not exist.".format(checkpoints)
        para_dict, opti_dict = fluid.dygraph.load_dygraph(checkpoints)
        net.set_dict(para_dict)
        if optimizer is not None:
            optimizer.set_dict(opti_dict)
        logger.info(
            logger.coloring("Finish initing model from {}".format(checkpoints),
                            "HEADER"))
        return

    pretrained_model = config.get('pretrained_model')
    load_static_weights = config.get('load_static_weights', False)
    if pretrained_model:
        if isinstance(pretrained_model,
                      list):  # load distillation pretrained model
            if not isinstance(load_static_weights, list):
                # Apply the single flag to every listed model.
                load_static_weights = [load_static_weights] * len(
                    pretrained_model)
            load_distillation_model(net, pretrained_model, load_static_weights)
        else:  # common load
            load_dygraph_pretrain(
                net,
                path=pretrained_model,
                load_static_weights=load_static_weights)
            logger.info(
                logger.coloring("Finish initing model from {}".format(
                    pretrained_model), "HEADER"))
W
WuHaobo 已提交
134 135


W
WuHaobo 已提交
136
def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
    """
    save model to the target path

    The state dicts are written under ``<model_path>/<epoch_id>/<prefix>``;
    the directory is created when missing.

    Args:
        net: model whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved; when None, only
            the model is saved.
        model_path: root output directory.
        epoch_id: name of the sub-directory (converted with ``str``).
        prefix: file prefix passed to ``save_dygraph``.
    """
    model_path = os.path.join(model_path, str(epoch_id))
    _mkdir_if_not_exist(model_path)
    model_prefix = os.path.join(model_path, prefix)

    # Both calls share the same prefix, as in the original code.
    # NOTE(review): presumably save_dygraph picks the file suffix from the
    # kind of state dict it receives - confirm against the Paddle docs.
    fluid.dygraph.save_dygraph(net.state_dict(), model_prefix)
    if optimizer is not None:
        fluid.dygraph.save_dygraph(optimizer.state_dict(), model_prefix)
    logger.info(
        logger.coloring("Already save model in {}".format(model_path),
                        "HEADER"))