save_load.py 6.7 KB
Newer Older
W
WuHaobo 已提交
1 2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
W
WuHaobo 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
W
WuHaobo 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14 15 16 17 18

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WuHaobo 已提交
19
import errno
W
WuHaobo 已提交
20
import os
W
WuHaobo 已提交
21
import re
W
WuHaobo 已提交
22
import shutil
W
WuHaobo 已提交
23
import tempfile
W
WuHaobo 已提交
24

25
import paddle
26
from paddle.static import load_program_state
27
from paddle.utils.download import get_weights_path_from_url
W
WuHaobo 已提交
28 29 30

from ppcls.utils import logger

31
# Public API of this module (what ``from ... import *`` exports).
__all__ = [
    'init_model', 'save_model', 'load_dygraph_pretrain',
    'load_dygraph_pretrain_from_url'
]
W
WuHaobo 已提交
32 33 34 35


def _mkdir_if_not_exist(path):
    """
W
WuHaobo 已提交
36
    mkdir if not exists, ignore the exception when multiprocess mkdir together
W
WuHaobo 已提交
37
    """
W
WuHaobo 已提交
38 39 40 41 42 43 44 45 46
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
W
WuHaobo 已提交
47
                raise OSError('Failed to mkdir {}'.format(path))
W
WuHaobo 已提交
48 49


littletomatodonkey's avatar
littletomatodonkey 已提交
50
def load_dygraph_pretrain(model, path=None, load_static_weights=False):
    """
    Load pretrained parameters into ``model``.

    Args:
        model: network whose ``state_dict()`` is read and whose
            ``set_dict()`` receives the loaded parameters.
        path (str): weight path prefix. Without ``load_static_weights``
            the file ``path + '.pdparams'`` is read with ``paddle.load``.
            With ``load_static_weights`` the path is passed to
            ``paddle.static.load_program_state``.
        load_static_weights (bool): load weights saved by a static-graph
            program. Parameters are matched by their variable ``name``.
            Parameters not found in the file are set back to their
            current values.

    Raises:
        ValueError: if ``path`` is None, or if neither a directory named
            ``path`` nor a file ``path + '.pdparams'`` exists.
    """
    if path is None:
        # Without this guard the default crashes with a TypeError inside
        # os.path, which hides the real problem.
        raise ValueError("Model pretrain path must be given, got None.")
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(path))
    if load_static_weights:
        pre_state_dict = load_program_state(path)
        param_state_dict = {}
        model_dict = model.state_dict()
        for key, param in model_dict.items():
            # Static-graph weights are keyed by variable name, not by the
            # dygraph structured key.
            weight_name = param.name
            if weight_name in pre_state_dict:
                logger.info('Load weight: {}, shape: {}'.format(
                    weight_name, pre_state_dict[weight_name].shape))
                param_state_dict[key] = pre_state_dict[weight_name]
            else:
                param_state_dict[key] = param
        model.set_dict(param_state_dict)
        return

    param_state_dict = paddle.load(path + ".pdparams")
    model.set_dict(param_state_dict)
    return
W
WuHaobo 已提交
72 73


74 75 76 77
def load_dygraph_pretrain_from_url(model,
                                   pretrained_url,
                                   use_ssld,
                                   load_static_weights=False):
    """
    Fetch pretrained weights from ``pretrained_url`` and load them into
    ``model``.

    Args:
        model: network passed on to ``load_dygraph_pretrain``.
        pretrained_url (str): URL of the weight file.
        use_ssld (bool): when true, every ``_pretrained`` in the URL is
            rewritten to ``_ssld_pretrained`` before fetching.
        load_static_weights (bool): forwarded to ``load_dygraph_pretrain``.
    """
    url = pretrained_url
    if use_ssld:
        url = url.replace("_pretrained", "_ssld_pretrained")
    # Presumably a local path to the downloaded file. ".pdparams" is
    # stripped because load_dygraph_pretrain expects a path prefix.
    weight_file = get_weights_path_from_url(url)
    weight_prefix = weight_file.replace(".pdparams", "")
    load_dygraph_pretrain(
        model, path=weight_prefix, load_static_weights=load_static_weights)


littletomatodonkey's avatar
littletomatodonkey 已提交
88 89
def load_distillation_model(model, pretrained_model, load_static_weights):
    """
    Load teacher (and optionally student) weights for distillation.

    Args:
        model: network exposing ``teacher`` and ``student`` sub-layers,
            either directly or through ``model._layers`` (presumably a
            data-parallel wrapper; verify against callers).
        pretrained_model (str | list[str]): weight path(s). The first entry
            is loaded into the teacher. The second entry, if present, is
            loaded into the student.
        load_static_weights (bool | list[bool]): per-path flag forwarded to
            ``load_dygraph_pretrain``. A single bool applies to every path.
    """
    logger.info("In distillation mode, teacher model will be "
                "loaded firstly before student model.")

    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]

    if not isinstance(load_static_weights, list):
        load_static_weights = [load_static_weights] * len(pretrained_model)

    teacher = model.teacher if hasattr(model,
                                       "teacher") else model._layers.teacher
    student = model.student if hasattr(model,
                                       "student") else model._layers.student
    load_dygraph_pretrain(
        teacher,
        path=pretrained_model[0],
        load_static_weights=load_static_weights[0])
    # Log the path actually used for the teacher, not the whole list.
    logger.info("Finish initing teacher model from {}".format(
        pretrained_model[0]))
    # load student model
    if len(pretrained_model) >= 2:
        load_dygraph_pretrain(
            student,
            path=pretrained_model[1],
            load_static_weights=load_static_weights[1])
        logger.info("Finish initing student model from {}".format(
            pretrained_model[1]))
littletomatodonkey's avatar
littletomatodonkey 已提交
116

littletomatodonkey's avatar
littletomatodonkey 已提交
117

118
def init_model(config, net, optimizer=None):
    """
    Initialize ``net`` from the weights referenced by ``config``.

    Resuming takes priority. When ``config['checkpoints']`` is set AND an
    ``optimizer`` is given, three things are restored and the metric info
    is returned:
      * network parameters from ``<checkpoints>.pdparams``
      * optimizer state from ``<checkpoints>.pdopt``
      * saved metric info from ``<checkpoints>.pdstates``

    Otherwise, if ``config['pretrained_model']`` is set, pretrained weights
    are loaded and None is returned. The distillation loader is used when
    ``config['use_distillation']`` is true.

    Args:
        config: mapping-like object read via ``.get``.
        net: the network to initialize.
        optimizer: optimizer whose state is restored when resuming.

    Returns:
        The object loaded from ``<checkpoints>.pdstates`` when resuming,
        otherwise None.
    """
    checkpoints = config.get('checkpoints')
    resuming = bool(checkpoints) and optimizer is not None
    if resuming:
        params_path = checkpoints + ".pdparams"
        opt_path = checkpoints + ".pdopt"
        assert os.path.exists(params_path), \
            "Given dir {}.pdparams not exist.".format(checkpoints)
        assert os.path.exists(opt_path), \
            "Given dir {}.pdopt not exist.".format(checkpoints)
        net_state = paddle.load(params_path)
        opt_state = paddle.load(opt_path)
        metric_dict = paddle.load(checkpoints + ".pdstates")
        net.set_dict(net_state)
        optimizer.set_state_dict(opt_state)
        logger.info("Finish load checkpoints from {}".format(checkpoints))
        return metric_dict

    pretrained_model = config.get('pretrained_model')
    load_static_weights = config.get('load_static_weights', False)
    use_distillation = config.get('use_distillation', False)
    if not pretrained_model:
        return None

    if use_distillation:
        load_distillation_model(net, pretrained_model, load_static_weights)
        return None

    # common (non-distillation) load
    load_dygraph_pretrain(
        net, path=pretrained_model, load_static_weights=load_static_weights)
    finished = "Finish load pretrained model from {}".format(pretrained_model)
    logger.info(logger.coloring(finished, "HEADER"))
    return None
W
WuHaobo 已提交
150 151


152 153 154 155 156 157 158 159 160 161 162 163 164
def _save_student_model(net, model_prefix):
    """
    save student model if the net is the network contains student
    """
    student_model_prefix = model_prefix + "_student.pdparams"
    if hasattr(net, "_layers"):
        net = net._layers
    if hasattr(net, "student"):
        paddle.save(net.student.state_dict(), student_model_prefix)
        logger.info("Already save student model in {}".format(
            student_model_prefix))


165 166 167 168 169 170
def save_model(net,
               optimizer,
               metric_info,
               model_path,
               model_name="",
               prefix='ppcls'):
    """
    Save a training checkpoint. Only rank 0 writes.

    Files written under ``<model_path>/<model_name>/``:
      * ``<prefix>.pdparams`` - ``net.state_dict()``
      * ``<prefix>.pdopt``    - ``optimizer.state_dict()``
      * ``<prefix>.pdstates`` - ``metric_info``
    ``<prefix>_student.pdparams`` is also written when ``net`` has a
    student sub-layer. Non-zero ranks return immediately without writing.
    """
    if paddle.distributed.get_rank() != 0:
        return
    save_dir = os.path.join(model_path, model_name)
    _mkdir_if_not_exist(save_dir)
    model_prefix = os.path.join(save_dir, prefix)

    # The student checkpoint (if any) is written before the full model.
    _save_student_model(net, model_prefix)

    params_file = model_prefix + ".pdparams"
    opt_file = model_prefix + ".pdopt"
    states_file = model_prefix + ".pdstates"
    paddle.save(net.state_dict(), params_file)
    paddle.save(optimizer.state_dict(), opt_file)
    paddle.save(metric_info, states_file)
    logger.info("Already save model in {}".format(save_dir))