checkpoint.py 7.3 KB
Newer Older
Q
qingqing01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import errno
import os
import time
import re
import numpy as np
import paddle
W
wangxinxin08 已提交
26
import paddle.nn as nn
Q
qingqing01 已提交
27 28 29 30 31 32 33 34 35 36 37 38
from .download import get_weights_path

from .logger import setup_logger
logger = setup_logger(__name__)


def is_url(path):
    """
    Check whether `path` refers to a remote weights location.

    Args:
        path (string): a local path or a URL.

    Returns:
        bool: True if `path` starts with one of the schemes 'http://',
        'https://' or 'ppdet://', otherwise False.
    """
    url_prefixes = ('http://', 'https://', 'ppdet://')
    return path.startswith(url_prefixes)
Q
qingqing01 已提交
42 43


44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
def _get_unique_endpoints(trainer_endpoints):
    """
    Pick one representative endpoint for each distinct IP.

    Endpoints are visited in sorted order, so for every IP the
    lexicographically smallest endpoint on that IP is kept. Because of this,
    every process selects the same set regardless of the order in which its
    environment lists the endpoints.

    Args:
        trainer_endpoints (iterable of str): endpoints of the form
            "ip:port". The argument is not modified.

    Returns:
        set of str: one endpoint per distinct IP.
    """
    seen_ips = set()
    unique_endpoints = set()
    # Sort a copy: this avoids depending on per-card environment ordering
    # without mutating the caller's list.
    for endpoint in sorted(trainer_endpoints):
        # rsplit keeps a bracketed IPv6 host such as "[::1]:6170" intact.
        # For "ip:port" it yields the same result as split(":")[0].
        ip = endpoint.rsplit(":", 1)[0]
        if ip in seen_ips:
            continue
        seen_ips.add(ip)
        unique_endpoints.add(endpoint)
    logger.info("unique_endpoints {}".format(unique_endpoints))
    return unique_endpoints


K
Kaipeng Deng 已提交
59
def get_weights_path_dist(path):
    """
    Resolve a weights URL to a local path, coordinating multi-process download.

    If the environment variables PADDLE_TRAINERS_NUM and PADDLE_TRAINER_ID are
    both set and more than one trainer is running:

    - the target path is computed with ``map_path(path, WEIGHTS_HOME)``;
    - if that file does not exist yet, every process creates a ``.lock`` file
      next to it;
    - only processes whose current endpoint is in the per-IP unique endpoint
      set call ``get_weights_path`` and then remove the lock;
    - all other processes poll once per second until the lock disappears.

    In every other case ``get_weights_path`` is called directly.

    Args:
        path (str): weights URL.

    Returns:
        str: local path of the weights.
    """
    env = os.environ
    if 'PADDLE_TRAINERS_NUM' not in env or 'PADDLE_TRAINER_ID' not in env:
        return get_weights_path(path)

    num_trainers = int(env['PADDLE_TRAINERS_NUM'])
    if num_trainers <= 1:
        return get_weights_path(path)

    from ppdet.utils.download import map_path, WEIGHTS_HOME
    weight_path = map_path(path, WEIGHTS_HOME)
    lock_path = weight_path + '.lock'
    if not os.path.exists(weight_path):
        from paddle.distributed import ParallelEnv
        unique_endpoints = _get_unique_endpoints(ParallelEnv()
                                                 .trainer_endpoints[:])
        try:
            os.makedirs(os.path.dirname(weight_path))
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
        # Every process touches the lock before the download starts, so a
        # waiting process never sees "no lock" before the downloader has
        # begun.
        # NOTE(review): there is a small window in which a process that passed
        # the exists() check above touches the lock after the downloader has
        # already removed it; that process would then wait forever.
        with open(lock_path, 'w'):  # touch
            os.utime(lock_path, None)
        if ParallelEnv().current_endpoint in unique_endpoints:
            try:
                get_weights_path(path)
            finally:
                # Release the lock even if the download failed. Otherwise the
                # other ranks would spin forever in the wait loop below.
                try:
                    os.remove(lock_path)
                except OSError as e:
                    # Tolerate the lock already being gone (presumably
                    # removed by another downloader sharing this filesystem).
                    if e.errno != errno.ENOENT:
                        raise
        else:
            while os.path.exists(lock_path):
                time.sleep(1)
    return weight_path


def _strip_postfix(path):
    path, ext = os.path.splitext(path)
    assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
            "Unknown postfix {} from weights".format(ext)
    return path


def load_weight(model, weight, optimizer=None):
    """
    Restore model parameters, and optionally optimizer states, from a checkpoint.

    Every key in ``model.state_dict()`` must be present in the checkpoint's
    ``.pdparams`` file. Each missing key is logged, and then an
    AssertionError is raised.

    Args:
        model (paddle.nn.Layer): model whose parameters are restored.
        weight (str): checkpoint path or URL, either without a postfix or
            ending in '.pdparams', '.pdopt' or '.pdmodel'.
        optimizer (paddle.optimizer.Optimizer, optional): if given and a
            '<path>.pdopt' file exists, its state is restored too.

    Returns:
        int: the 'last_epoch' value stored in the optimizer state file. It is
        0 if no optimizer state was loaded or the file has no 'last_epoch'.

    Raises:
        ValueError: if '<path>.pdparams' does not exist.
        AssertionError: if some model keys are missing from the checkpoint.
    """
    if is_url(weight):
        weight = get_weights_path_dist(weight)

    path = _strip_postfix(weight)
    pdparam_path = path + '.pdparams'
    if not os.path.exists(pdparam_path):
        raise ValueError("Model pretrain path {} does not "
                         "exists.".format(pdparam_path))

    param_state_dict = paddle.load(pdparam_path)
    model_dict = model.state_dict()
    model_weight = {}
    incorrect_keys = 0

    for key in model_dict:
        if key in param_state_dict:
            model_weight[key] = param_state_dict[key]
        else:
            logger.info('Unmatched key: {}'.format(key))
            incorrect_keys += 1

    # Adjacent literals are joined before .format() is applied. The previous
    # backslash continuation inside the string embedded a run of spaces into
    # the message.
    assert incorrect_keys == 0, ("Load weight {} incorrectly, {} keys "
                                 "unmatched, please check again.".format(
                                     weight, incorrect_keys))
    logger.info('Finish resuming model weights: {}'.format(pdparam_path))

    model.set_dict(model_weight)

    last_epoch = 0
    if optimizer is not None and os.path.exists(path + '.pdopt'):
        optim_state_dict = paddle.load(path + '.pdopt')
        # Workaround for a resume bug (original note: "will it be fixed in
        # paddle 2.0"). Keys missing from the checkpoint are filled in from
        # the optimizer's current state. The state dict is fetched once
        # rather than once per key.
        current_state = optimizer.state_dict()
        for key, value in current_state.items():
            if key not in optim_state_dict:
                optim_state_dict[key] = value
        if 'last_epoch' in optim_state_dict:
            last_epoch = optim_state_dict.pop('last_epoch')
        optimizer.set_state_dict(optim_state_dict)

    return last_epoch
Q
qingqing01 已提交
142 143


K
Kaipeng Deng 已提交
144
def load_pretrain_weight(model, pretrain_weight):
    """
    Load pretrained parameters into `model`, skipping mismatched entries.

    Checkpoint entries whose shape differs from the model's parameter are
    logged and dropped. Model parameters absent from the checkpoint are
    logged at debug level, and their top-level module names are summarized
    at info level.

    Args:
        model (paddle.nn.Layer): model to load parameters into.
        pretrain_weight (str): weights path or URL, either without a postfix
            or ending in '.pdparams', '.pdopt' or '.pdmodel'.

    Raises:
        ValueError: if '<path>.pdparams' does not exist.
    """
    if is_url(pretrain_weight):
        pretrain_weight = get_weights_path_dist(pretrain_weight)

    path = _strip_postfix(pretrain_weight)
    weights_path = path + '.pdparams'
    # Check the file that is actually loaded below. Accepting a bare
    # directory or file at `path` only deferred the failure to paddle.load.
    if not os.path.exists(weights_path):
        raise ValueError("Model pretrain path `{}` does not exists. "
                         "If you don't want to load pretrain model, "
                         "please delete `pretrain_weights` field in "
                         "config file.".format(path))

    model_dict = model.state_dict()

    param_state_dict = paddle.load(weights_path)
    lack_modules = set()
    for name, weight in model_dict.items():
        if name in param_state_dict:
            ckpt_shape = list(param_state_dict[name].shape)
            # Normalize both shapes to lists. A tuple shape would otherwise
            # never compare equal to a list and every entry would be dropped.
            if list(weight.shape) != ckpt_shape:
                logger.info(
                    '{} not used, shape {} unmatched with {} in model.'.format(
                        name, ckpt_shape, weight.shape))
                param_state_dict.pop(name, None)
        else:
            lack_modules.add(name.split('.')[0])
            logger.debug('Lack weights: {}'.format(name))

    if lack_modules:
        # Sorted so the log line is deterministic across runs.
        logger.info('Lack weights of modules: {}'.format(', '.join(
            sorted(lack_modules))))

    model.set_dict(param_state_dict)
    logger.info('Finish loading model weights: {}'.format(weights_path))
Q
qingqing01 已提交
179 180 181 182 183


def save_model(model, optimizer, save_dir, save_name, last_epoch):
    """
    Write model parameters and optimizer states to disk.

    Only the process with distributed rank 0 writes anything; all other
    ranks return immediately.

    Args:
        model (paddle.nn.Layer|dict): the Layer instance whose state_dict()
            is saved, or an already-built parameter dict saved as is.
        optimizer (paddle.optimizer.Optimizer): the Optimizer instance whose
            states are saved.
        save_dir (str): output directory; created if it does not exist.
        save_name (str): file name stem. '.pdparams' and '.pdopt' are
            appended to it.
        last_epoch (int): the epoch index, stored under 'last_epoch' in the
            optimizer state file.
    """
    if paddle.distributed.get_rank() != 0:
        return

    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    prefix = os.path.join(save_dir, save_name)

    if isinstance(model, nn.Layer):
        params = model.state_dict()
    else:
        assert isinstance(model,
                          dict), 'model is not a instance of nn.layer or dict'
        params = model
    paddle.save(params, prefix + ".pdparams")

    opt_state = optimizer.state_dict()
    opt_state['last_epoch'] = last_epoch
    paddle.save(opt_state, prefix + ".pdopt")
    logger.info("Save checkpoint: {}".format(save_dir))