save_load.py 7.5 KB
Newer Older
W
WuHaobo 已提交
1 2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
W
WuHaobo 已提交
3 4 5
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
W
WuHaobo 已提交
6 7 8
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
W
WuHaobo 已提交
9 10 11 12 13
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14 15 16 17 18

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WuHaobo 已提交
19
import errno
W
WuHaobo 已提交
20 21
import os

22
import paddle
R
root 已提交
23
from . import logger
24
from .download import get_weights_path_from_url
W
WuHaobo 已提交
25

26
# Names exported by ``from ... import *``. Every entry must be defined in
# this module, otherwise the star-import raises AttributeError.
__all__ = [
    'init_model', 'ModelSaver', 'load_dygraph_pretrain',
    'load_dygraph_pretrain_from_url'
]
W
WuHaobo 已提交
27 28


29
def load_dygraph_pretrain(model, path=None):
    """
    Load pretrained parameters from ``path + '.pdparams'`` into a model.

    Args:
        model: a layer exposing ``set_dict``, or a list of layers. In the list
            case every element that has ``set_dict`` receives the same state
            dict; elements without it are skipped silently.
        path (str): weight path WITHOUT the ``.pdparams`` suffix.

    Raises:
        ValueError: if ``path`` is None, or if ``path`` is neither a directory
            nor has an existing ``path + '.pdparams'`` file.
    """
    if path is None:
        # Report a missing path the same way as a missing file, instead of
        # letting ``None + '.pdparams'`` fail with a TypeError.
        raise ValueError("Model pretrain path is None.")
    # NOTE(review): a bare directory passes this check, yet paddle.load below
    # still reads ``path + '.pdparams'`` -- confirm a directory is meant to
    # be accepted here.
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {}.pdparams does not "
                         "exists.".format(path))
    param_state_dict = paddle.load(path + ".pdparams")
    if isinstance(model, list):
        for m in model:
            if hasattr(m, 'set_dict'):
                m.set_dict(param_state_dict)
    else:
        model.set_dict(param_state_dict)
    logger.info("Finish load pretrained model from {}".format(path))
    return
W
WuHaobo 已提交
42 43


weixin_46524038's avatar
weixin_46524038 已提交
44 45 46 47 48
def load_dygraph_pretrain_from_url(model,
                                   pretrained_url,
                                   use_ssld=False,
                                   use_imagenet22k_pretrained=False,
                                   use_imagenet22kto1k_pretrained=False):
    """
    Fetch pretrained weights by URL and load them into ``model``.

    Each enabled flag rewrites every ``_pretrained`` in the URL, in this
    order: ssld, 22k, 22kto1k. Enabling several flags therefore stacks the
    tags (e.g. ``_ssld_22k_pretrained``).

    Args:
        model: layer (or list of layers) passed to ``load_dygraph_pretrain``.
        pretrained_url (str): URL of the base ``*_pretrained`` weight file.
        use_ssld (bool): insert the ``_ssld`` tag.
        use_imagenet22k_pretrained (bool): insert the ``_22k`` tag.
        use_imagenet22kto1k_pretrained (bool): insert the ``_22kto1k`` tag.
    """
    if use_ssld:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_ssld_pretrained")
    if use_imagenet22k_pretrained:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_22k_pretrained")
    if use_imagenet22kto1k_pretrained:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_22kto1k_pretrained")
    local_weight_path = get_weights_path_from_url(pretrained_url)
    # load_dygraph_pretrain appends ".pdparams" itself, so drop only the
    # trailing suffix. str.replace would also rewrite a ".pdparams" that
    # happens to appear elsewhere in the local path (e.g. a directory name).
    suffix = ".pdparams"
    if local_weight_path.endswith(suffix):
        local_weight_path = local_weight_path[:-len(suffix)]
    load_dygraph_pretrain(model, path=local_weight_path)
    return


64
def load_distillation_model(model, pretrained_model):
    """
    Load teacher (and optionally student) weights into a distillation model.

    Args:
        model: a layer with ``teacher`` and ``student`` attributes, or a
            wrapper whose ``_layers`` attribute has them.
        pretrained_model (str | list[str]): weight path(s) without the
            ``.pdparams`` suffix. Entry 0 is loaded into the teacher. Entry 1,
            if present, is loaded into the student. Further entries are
            ignored.

    Raises:
        ValueError: if ``pretrained_model`` is an empty list.
    """
    logger.info("In distillation mode, teacher model will be "
                "loaded firstly before student model.")

    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]
    if not pretrained_model:
        raise ValueError("pretrained_model must contain at least the teacher "
                         "weight path, but got an empty list.")

    # Fall back to ``_layers`` when the model is wrapped (presumably a
    # parallel-training wrapper -- confirm with callers).
    teacher = model.teacher if hasattr(model,
                                       "teacher") else model._layers.teacher
    student = model.student if hasattr(model,
                                       "student") else model._layers.student
    load_dygraph_pretrain(teacher, path=pretrained_model[0])
    # Log the path actually loaded, not the whole list.
    logger.info("Finish initing teacher model from {}".format(
        pretrained_model[0]))
    # load student model
    if len(pretrained_model) >= 2:
        load_dygraph_pretrain(student, path=pretrained_model[1])
        logger.info("Finish initing student model from {}".format(
            pretrained_model[1]))
littletomatodonkey's avatar
littletomatodonkey 已提交
83

littletomatodonkey's avatar
littletomatodonkey 已提交
84

F
flytocc 已提交
85 86 87 88
def init_model(config,
               net,
               optimizer=None,
               loss: "paddle.nn.Layer" = None,
               model_ema=None):
    """
    Load model from checkpoint or pretrained_model.

    Resuming takes priority. It happens when ``config['checkpoints']`` is set
    AND ``optimizer`` is given, and does the following:

    - loads ``<ckpt>.pdparams`` into ``net`` (and into ``loss`` when given);
    - loads ``<ckpt>.pdopt`` into every optimizer;
    - loads ``<ckpt>.ema.pdparams`` into ``model_ema.module`` when given;
    - returns the content of ``<ckpt>.pdstates``.

    Otherwise, if ``config['pretrained_model']`` is set, it is loaded into
    ``net``. The distillation loader is used when
    ``config['use_distillation']`` is true.

    Args:
        config (dict): reads ``checkpoints``, ``pretrained_model`` and
            ``use_distillation``.
        net: model to initialize.
        optimizer: sequence of optimizers to restore on resume, or None.
        loss: optional loss layer whose parameters were saved in the same
            ``.pdparams`` file as the model's.
        model_ema: optional EMA wrapper exposing ``module``.

    Returns:
        The object loaded from ``<ckpt>.pdstates`` when resuming, otherwise
        None.
    """
    checkpoints = config.get('checkpoints')
    if checkpoints and optimizer is not None:
        assert os.path.exists(checkpoints + ".pdparams"), \
            "Given dir {}.pdparams not exist.".format(checkpoints)
        assert os.path.exists(checkpoints + ".pdopt"), \
            "Given dir {}.pdopt not exist.".format(checkpoints)
        # load state dict
        opti_dict = paddle.load(checkpoints + ".pdopt")
        para_dict = paddle.load(checkpoints + ".pdparams")
        metric_dict = paddle.load(checkpoints + ".pdstates")
        # set state dict
        net.set_state_dict(para_dict)
        # ``loss`` defaults to None. Calling set_state_dict unconditionally
        # raised AttributeError for callers that resume without a loss layer.
        if loss is not None:
            loss.set_state_dict(para_dict)
        for i, opt in enumerate(optimizer):
            # Either one saved state per optimizer (a list) or a single state
            # shared by all of them.
            opt.set_state_dict(opti_dict[i] if isinstance(
                opti_dict, list) else opti_dict)
        if model_ema is not None:
            assert os.path.exists(checkpoints + ".ema.pdparams"), \
                "Given dir {}.ema.pdparams not exist.".format(checkpoints)
            para_ema_dict = paddle.load(checkpoints + ".ema.pdparams")
            model_ema.module.set_state_dict(para_ema_dict)
        logger.info("Finish load checkpoints from {}".format(checkpoints))
        return metric_dict

    pretrained_model = config.get('pretrained_model')
    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
        if use_distillation:
            load_distillation_model(net, pretrained_model)
        else:  # common load
            # load_dygraph_pretrain already logs
            # "Finish load pretrained model from <path>", so no second log here.
            load_dygraph_pretrain(net, path=pretrained_model)
G
gaotingquan 已提交
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152


def _mkdir_if_not_exist(path):
    """
    mkdir if not exists, ignore the exception when multiprocess mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


def _extract_student_weights(all_params, student_prefix="Student."):
    s_params = {
        key[len(student_prefix):]: all_params[key]
        for key in all_params if student_prefix in key
    }
    return s_params


class ModelSaver(object):
    """
    Save model, loss, optimizer and EMA state under
    ``<Global.output_dir>/<Arch.name>/``.
    """

    def __init__(self, config, net, loss, opt, model_ema):
        """
        Args:
            config (dict): reads ``config["Arch"]["name"]`` and
                ``config["Global"]["output_dir"]``.
            net: model whose ``state_dict`` is saved.
            loss: optional loss layer. Its parameters are merged into the
                model's ``.pdparams`` file.
            opt: sequence of optimizers, or None to skip the ``.pdopt`` file.
            model_ema: optional EMA wrapper exposing ``module``.
        """
        self.net = net
        self.loss = loss
        self.opt = opt
        self.model_ema = model_ema

        arch_name = config["Arch"]["name"]
        self.output_dir = os.path.join(config["Global"]["output_dir"],
                                       arch_name)
        _mkdir_if_not_exist(self.output_dir)

    def save(self, metric_info, prefix='ppcls', save_student_model=False):
        """
        Write checkpoint files whose basename is ``<output_dir>/<prefix>``.

        Only rank 0 writes anything. The files written are:

        - ``.pdparams``: model parameters, plus loss parameters when set.
        - ``_student.pdparams``: student-only parameters, when requested
          and at least one is found.
        - ``.ema.pdparams``: EMA parameters, when ``model_ema`` is set.
        - ``.pdopt``: list of optimizer states, when ``opt`` is set.
        - ``.pdstates``: ``metric_info``.
        """
        if paddle.distributed.get_rank() != 0:
            return

        # Despite the name, this is a file basename; suffixes are appended below.
        save_dir = os.path.join(self.output_dir, prefix)

        params_state_dict = self.net.state_dict()
        loss = self.loss
        if loss is not None:
            loss_state_dict = loss.state_dict()
            keys_inter = set(params_state_dict.keys()) & set(
                loss_state_dict.keys())
            # Model and loss share one file, so clashing keys would silently
            # overwrite each other.
            assert len(keys_inter) == 0, \
                f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
            params_state_dict.update(loss_state_dict)

        if save_student_model:
            s_params = _extract_student_weights(params_state_dict)
            if len(s_params) > 0:
                paddle.save(s_params, save_dir + "_student.pdparams")

        paddle.save(params_state_dict, save_dir + ".pdparams")
        model_ema = self.model_ema
        if model_ema is not None:
            paddle.save(model_ema.module.state_dict(),
                        save_dir + ".ema.pdparams")
        optimizer = self.opt
        # Iterating None used to crash after .pdparams was already written,
        # which left a partial checkpoint behind.
        if optimizer is not None:
            paddle.save([opt.state_dict() for opt in optimizer],
                        save_dir + ".pdopt")
        paddle.save(metric_info, save_dir + ".pdstates")
        logger.info("Already save model in {}".format(save_dir))