save_load.py 5.6 KB
Newer Older
W
WuHaobo 已提交
1 2
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
W
WuHaobo 已提交
14 15 16 17 18

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

W
WuHaobo 已提交
19
import errno
W
WuHaobo 已提交
20 21
import os

22
import paddle
W
WuHaobo 已提交
23
from ppcls.utils import logger
24
from .download import get_weights_path_from_url
W
WuHaobo 已提交
25

26
__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
W
WuHaobo 已提交
27 28 29 30


def _mkdir_if_not_exist(path):
    """
    Create directory ``path`` (with missing parents) if it does not exist.

    When several processes try to create the same directory at once, the
    losers get ``EEXIST``; that case is only logged as a warning as long as
    ``path`` is a directory afterwards.

    Args:
        path (str): directory to create.

    Raises:
        OSError: if the directory could not be created for any other reason.
            The underlying error is attached as ``__cause__``.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'be happy if some process has already created {}'.format(
                        path))
            else:
                # Chain explicitly so the original errno/strerror is not lost
                # behind the generic message.
                raise OSError('Failed to mkdir {}'.format(path)) from e
W
WuHaobo 已提交
43 44


45
def load_dygraph_pretrain(model, path=None):
    """
    Load pretrained parameters from ``path + '.pdparams'`` into ``model``.

    Args:
        model: an object exposing ``set_dict``, or a list of such objects.
            In a list, entries without a ``set_dict`` attribute are skipped.
        path (str): weights path WITHOUT the ``.pdparams`` suffix.

    Raises:
        ValueError: if ``path`` is None, or if ``path`` is not a directory
            and ``path + '.pdparams'`` does not exist.
    """
    # The signature defaults to None; fail with a clear message instead of an
    # opaque TypeError from os.path / string concatenation.
    if path is None:
        raise ValueError("Model pretrain path must be given, but got None.")
    # NOTE(review): a directory passes this check, yet the load below still
    # reads `path + '.pdparams'` -- confirm whether the isdir case is intended.
    if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
        raise ValueError("Model pretrain path {}.pdparams does not "
                         "exists.".format(path))
    param_state_dict = paddle.load(path + ".pdparams")
    if isinstance(model, list):
        for m in model:
            if hasattr(m, 'set_dict'):
                m.set_dict(param_state_dict)
    else:
        model.set_dict(param_state_dict)
    return
W
WuHaobo 已提交
57 58


C
cuicheng01 已提交
59
def load_dygraph_pretrain_from_url(model, pretrained_url, use_ssld=False):
    """
    Resolve ``pretrained_url`` to a local weights path and load it.

    Args:
        model: model (or list of models) forwarded to
            ``load_dygraph_pretrain``.
        pretrained_url (str): URL of the ``.pdparams`` weights file.
        use_ssld (bool): when True, ``"_pretrained"`` in the URL is rewritten
            to ``"_ssld_pretrained"`` before resolving it.
    """
    if use_ssld:
        pretrained_url = pretrained_url.replace("_pretrained",
                                                "_ssld_pretrained")
    local_weight_path = get_weights_path_from_url(pretrained_url)
    # load_dygraph_pretrain appends ".pdparams" itself, so drop it here.
    # Only the trailing suffix is removed: a plain str.replace would also
    # strip ".pdparams" from directory names earlier in the path.
    suffix = ".pdparams"
    if local_weight_path.endswith(suffix):
        local_weight_path = local_weight_path[:-len(suffix)]
    load_dygraph_pretrain(model, path=local_weight_path)
    return


69
def load_distillation_model(model, pretrained_model):
    """
    Load pretrained weights for the teacher (and optionally the student).

    Args:
        model: object exposing ``teacher`` and ``student`` sub-models, either
            directly or through its ``_layers`` attribute.
        pretrained_model (str | list[str]): weights path(s) without the
            ``.pdparams`` suffix. The first entry is loaded into the teacher;
            a second entry, if present, is loaded into the student.
    """
    logger.info("In distillation mode, teacher model will be "
                "loaded firstly before student model.")

    if not isinstance(pretrained_model, list):
        pretrained_model = [pretrained_model]

    # Fall back to `model._layers` when the sub-model is not a direct
    # attribute. NOTE(review): presumably a parallel wrapper -- confirm.
    teacher = model.teacher if hasattr(model,
                                       "teacher") else model._layers.teacher
    student = model.student if hasattr(model,
                                       "student") else model._layers.student

    teacher_path = pretrained_model[0]
    load_dygraph_pretrain(teacher, path=teacher_path)
    # Report the path that was actually loaded, not the whole list.
    logger.info("Finish initing teacher model from {}".format(teacher_path))

    # load student model
    if len(pretrained_model) >= 2:
        student_path = pretrained_model[1]
        load_dygraph_pretrain(student, path=student_path)
        logger.info("Finish initing student model from {}".format(
            student_path))
littletomatodonkey's avatar
littletomatodonkey 已提交
88

littletomatodonkey's avatar
littletomatodonkey 已提交
89

90
def init_model(config, net, optimizer=None, loss: paddle.nn.Layer=None):
    """
    load model from checkpoint or pretrained_model

    If ``config['checkpoints']`` is set and ``optimizer`` is given, training
    state is restored from ``<checkpoints>.pdparams``, ``.pdopt`` and
    ``.pdstates``. Otherwise, if ``config['pretrained_model']`` is set, only
    the weights are loaded.

    Args:
        config: mapping-like object read through ``.get`` for the keys
            ``checkpoints``, ``pretrained_model`` and ``use_distillation``.
        net: model to initialise.
        optimizer: sequence of optimizers to restore, or None to skip
            checkpoint resuming.
        loss: optional loss layer whose parameters are stored in the same
            ``.pdparams`` file as the model.

    Returns:
        The object loaded from ``<checkpoints>.pdstates`` when resuming from
        a checkpoint, otherwise None.
    """
    checkpoints = config.get('checkpoints')
    if checkpoints and optimizer is not None:
        assert os.path.exists(checkpoints + ".pdparams"), \
            "Given dir {}.pdparams not exist.".format(checkpoints)
        assert os.path.exists(checkpoints + ".pdopt"), \
            "Given dir {}.pdopt not exist.".format(checkpoints)
        # load state dict
        opti_dict = paddle.load(checkpoints + ".pdopt")
        para_dict = paddle.load(checkpoints + ".pdparams")
        metric_dict = paddle.load(checkpoints + ".pdstates")
        # set state dict
        net.set_state_dict(para_dict)
        # `loss` is optional (defaults to None); only restore it when given.
        if loss is not None:
            loss.set_state_dict(para_dict)
        # save_model stores one state dict per optimizer in a list, so each
        # optimizer must get its own entry. A bare dict is still accepted.
        for i in range(len(optimizer)):
            if isinstance(opti_dict, list):
                optimizer[i].set_state_dict(opti_dict[i])
            else:
                optimizer[i].set_state_dict(opti_dict)
        logger.info("Finish load checkpoints from {}".format(checkpoints))
        return metric_dict

    pretrained_model = config.get('pretrained_model')
    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
        if use_distillation:
            load_distillation_model(net, pretrained_model)
        else:  # common load
            load_dygraph_pretrain(net, path=pretrained_model)
            logger.info(
                logger.coloring("Finish load pretrained model from {}".format(
                    pretrained_model), "HEADER"))
W
WuHaobo 已提交
122 123


124 125 126 127 128
def save_model(net,
               optimizer,
               metric_info,
               model_path,
               model_name="",
               prefix='ppcls',
               loss: paddle.nn.Layer=None):
    """
    save model to the target path

    Only the process with distributed rank 0 writes; other ranks return
    immediately. Three files are written under
    ``<model_path>/<model_name>/``: ``<prefix>.pdparams`` (model parameters,
    merged with the loss parameters when ``loss`` is given),
    ``<prefix>.pdopt`` (a list with one state dict per optimizer) and
    ``<prefix>.pdstates`` (``metric_info``).

    Args:
        net: model whose ``state_dict()`` is saved.
        optimizer: iterable of optimizers.
        metric_info: object saved as-is to the ``.pdstates`` file.
        model_path (str): root output directory.
        model_name (str): sub-directory under ``model_path``.
        prefix (str): file name stem of the saved files.
        loss: optional loss layer; its state-dict keys must not collide with
            those of ``net``.
    """
    if paddle.distributed.get_rank() != 0:
        return
    model_path = os.path.join(model_path, model_name)
    _mkdir_if_not_exist(model_path)
    model_path = os.path.join(model_path, prefix)

    params_state_dict = net.state_dict()
    # `loss` is optional (defaults to None); merge its parameters only when
    # it is given.
    if loss is not None:
        loss_state_dict = loss.state_dict()
        keys_inter = set(params_state_dict.keys()) & set(
            loss_state_dict.keys())
        assert len(keys_inter) == 0, \
            f"keys in model and loss state_dict must be unique, but got intersection {keys_inter}"
        params_state_dict.update(loss_state_dict)

    paddle.save(params_state_dict, model_path + ".pdparams")
    paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
    paddle.save(metric_info, model_path + ".pdstates")
    logger.info("Already save model in {}".format(model_path))