checkpoint.py 3.5 KB
Newer Older
1 2 3 4 5
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

W
wangguanzhong 已提交
6
import errno
7
import os
8
import time
9
import re
F
FDInSky 已提交
10
import numpy as np
W
wangguanzhong 已提交
11
import paddle
12 13 14 15
import paddle.fluid as fluid
from .download import get_weights_path


F
FDInSky 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
def get_ckpt_path(path):
    """Resolve a checkpoint location to a local path, downloading URLs.

    Paths that do not start with ``http://`` or ``https://`` are returned
    unchanged. URLs are fetched through ``get_weights_path``.

    When ``PADDLE_TRAINERS_NUM`` and ``PADDLE_TRAINER_ID`` are both set and
    there is more than one trainer, only trainer 0 performs the download.
    While downloading, it holds a ``<weight_path>.lock`` file. Every other
    trainer polls once per second until the weights exist and the lock is
    gone.

    Args:
        path (str): local checkpoint path or weights URL.

    Returns:
        str: local path of the checkpoint.
    """
    if not path.startswith(('http://', 'https://')):
        return path

    env = os.environ
    if 'PADDLE_TRAINERS_NUM' not in env or 'PADDLE_TRAINER_ID' not in env:
        return get_weights_path(path)

    trainer_id = int(env['PADDLE_TRAINER_ID'])
    num_trainers = int(env['PADDLE_TRAINERS_NUM'])
    if num_trainers <= 1:
        return get_weights_path(path)

    from ppdet.utils.download import map_path, WEIGHTS_HOME
    weight_path = map_path(path, WEIGHTS_HOME)
    lock_path = weight_path + '.lock'
    if not os.path.exists(weight_path):
        if trainer_id == 0:
            try:
                os.makedirs(os.path.dirname(weight_path))
            except OSError as e:
                if e.errno != errno.EEXIST:
                    raise
            # Only trainer 0 ever creates the lock. If other trainers also
            # touched it, one could re-create it after trainer 0 removed it
            # and then wait on it forever.
            with open(lock_path, 'w'):  # touch
                os.utime(lock_path, None)
            get_weights_path(path)
            os.remove(lock_path)
        else:
            # Waiting only on the lock would let a trainer slip through
            # before trainer 0 has created it. Also require the weights to
            # exist. Trainer 0 creates the lock before downloading, so
            # "weights present and no lock" means the download is finished.
            while (not os.path.exists(weight_path) or
                   os.path.exists(lock_path)):
                time.sleep(1)
    return weight_path


F
FDInSky 已提交
49
def _load_static_weights(model, ckpt):
    """Copy static-graph weights from ``ckpt`` into ``model`` by tensor name.

    Each entry of ``model.state_dict()`` is looked up in
    ``fluid.load_program_state(ckpt)`` by its ``.name``. Matches are loaded
    and logged; entries without a match keep their current value.
    """
    pre_state_dict = fluid.load_program_state(ckpt)
    model_dict = model.state_dict()
    param_state_dict = {}
    for key in model_dict:
        weight_name = model_dict[key].name
        if weight_name in pre_state_dict:
            print('Load weight: {}, shape: {}'.format(
                weight_name, pre_state_dict[weight_name].shape))
            param_state_dict[key] = pre_state_dict[weight_name]
        else:
            param_state_dict[key] = model_dict[key]
    model.set_dict(param_state_dict)
    return model


def load_dygraph_ckpt(model,
                      optimizer=None,
                      pretrain_ckpt=None,
                      ckpt=None,
                      ckpt_type=None,
                      exclude_params=None,
                      load_static_weights=False):
    """Load a checkpoint into ``model`` (and optionally ``optimizer``).

    Args:
        model: dygraph model exposing ``state_dict``/``set_dict``. For
            ``ckpt_type == 'pretrain'`` it must also expose ``.backbone``.
        optimizer: optimizer exposing ``set_dict``. Required when
            ``ckpt_type == 'resume'`` on the dygraph path.
        pretrain_ckpt: checkpoint used when ``ckpt_type == 'pretrain'`` and
            ``ckpt`` is None.
        ckpt: local checkpoint path or URL (resolved by ``get_ckpt_path``).
        ckpt_type: one of ``'pretrain'``, ``'resume'``, ``'finetune'`` or
            None.
            - ``'pretrain'`` loads into ``model.backbone`` only.
            - ``'resume'`` also restores the optimizer state.
        exclude_params: iterable of state-dict keys to drop before loading.
            Only applied on the dygraph path. None means "drop nothing".
        load_static_weights: if True, load a static-graph checkpoint by tensor
            name and return immediately. ``exclude_params`` and the optimizer
            are ignored on this path.

    Returns:
        The same ``model`` object, with weights loaded.

    Raises:
        AssertionError: on an unknown ``ckpt_type``, a missing checkpoint, or
            a ``'resume'`` without an optimizer or optimizer state.
    """
    assert ckpt_type in ['pretrain', 'resume', 'finetune', None]
    if ckpt_type == 'pretrain' and ckpt is None:
        ckpt = pretrain_ckpt
    # Fail with a clear message instead of an AttributeError in get_ckpt_path.
    assert ckpt is not None, \
        "No checkpoint given to load (ckpt_type={}).".format(ckpt_type)
    ckpt = get_ckpt_path(ckpt)
    assert os.path.exists(ckpt), "Path {} does not exist.".format(ckpt)

    if load_static_weights:
        return _load_static_weights(model, ckpt)

    param_state_dict, optim_state_dict = fluid.load_dygraph(ckpt)
    for k in exclude_params or ():
        param_state_dict.pop(k, None)

    if ckpt_type == 'pretrain':
        model.backbone.set_dict(param_state_dict)
    else:
        model.set_dict(param_state_dict)

    if ckpt_type == 'resume':
        assert optimizer is not None, \
            "An optimizer is required to resume training."
        assert optim_state_dict, "Can't Resume Last Training's Optimizer State!!!"
        optimizer.set_dict(optim_state_dict)
    return model


W
wangguanzhong 已提交
93
def save_dygraph_ckpt(model, optimizer, save_dir, save_name):
    """Save model and optimizer state under the prefix ``save_dir/save_name``.

    ``save_dir`` is created if needed; an empty ``save_dir`` means the current
    directory. Both state dicts are passed to ``fluid.dygraph.save_dygraph``
    with the same prefix.

    NOTE(review): this relies on ``save_dygraph`` choosing a different file
    suffix for parameter vs. optimizer state (``.pdparams`` / ``.pdopt`` in
    Paddle 1.x), so the two calls do not overwrite each other. Confirm
    against the installed Paddle version.
    """
    if save_dir:
        # try/except instead of exists()+makedirs(): avoids a race when
        # several processes save concurrently, and stays Python 2 compatible
        # (no exist_ok).
        try:
            os.makedirs(save_dir)
        except OSError as e:
            if e.errno != errno.EEXIST:
                raise
    save_path = os.path.join(save_dir, save_name)
    fluid.dygraph.save_dygraph(model.state_dict(), save_path)
    fluid.dygraph.save_dygraph(optimizer.state_dict(), save_path)
    print("Save checkpoint:", save_path)