checkpoint.py 3.0 KB
Newer Older
1 2 3 4 5 6
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
7
import time
8
import re
F
FDInSky 已提交
9
import numpy as np
10 11 12 13
import paddle.fluid as fluid
from .download import get_weights_path


F
FDInSky 已提交
14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
def get_ckpt_path(path):
    """Resolve a checkpoint location to a local filesystem path.

    Anything that is not an ``http://`` / ``https://`` URL is returned
    unchanged. URLs are fetched through ``get_weights_path``. When the
    ``PADDLE_TRAINERS_NUM`` and ``PADDLE_TRAINER_ID`` environment variables
    are both set and describe more than one trainer, only trainer 0 performs
    the download; the other trainers poll a ``<weight_path>.lock`` file and
    continue once it has been removed.

    Args:
        path: Local path or URL of the weights. ``None`` is passed through.

    Returns:
        The local path of the weights, or ``None`` if ``path`` is ``None``.

    Raises:
        OSError: If the cache directory cannot be created for a reason other
            than it already existing.
    """
    # Callers may have no checkpoint configured; do not crash on .startswith.
    if path is None:
        return None
    if not path.startswith(('http://', 'https://')):
        return path

    env = os.environ
    if 'PADDLE_TRAINERS_NUM' not in env or 'PADDLE_TRAINER_ID' not in env:
        # Not launched in distributed mode: download directly.
        return get_weights_path(path)

    trainer_id = int(env['PADDLE_TRAINER_ID'])
    num_trainers = int(env['PADDLE_TRAINERS_NUM'])
    if num_trainers <= 1:
        return get_weights_path(path)

    # `errno` is not imported at module level; import it here so the EEXIST
    # check below cannot fail with a NameError.
    import errno
    from ppdet.utils.download import map_path, WEIGHTS_HOME

    weight_path = map_path(path, WEIGHTS_HOME)
    lock_path = weight_path + '.lock'
    if not os.path.exists(weight_path):
        try:
            os.makedirs(os.path.dirname(weight_path))
        except OSError as e:
            # Another trainer may have created the directory first.
            if e.errno != errno.EEXIST:
                raise
        with open(lock_path, 'w'):  # touch
            os.utime(lock_path, None)
        if trainer_id == 0:
            get_weights_path(path)
            # NOTE(review): if the download raises, the lock file is left in
            # place and the other trainers keep waiting.
            os.remove(lock_path)
        else:
            # Wait for trainer 0 to finish downloading and drop the lock.
            while os.path.exists(lock_path):
                time.sleep(1)
    return weight_path


F
FDInSky 已提交
47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87
def load_dygraph_ckpt(model,
                      optimizer,
                      pretrain_ckpt=None,
                      ckpt=None,
                      ckpt_type='pretrain',
                      exclude_params=[],
                      open_debug=False):
    """Load weights (and optionally optimizer state) into a dygraph model.

    Args:
        model: Model to load into. For ``ckpt_type='pretrain'`` the weights
            are applied to ``model.backbone`` only.
        optimizer: Optimizer whose state is restored when
            ``ckpt_type='resume'``; unused otherwise.
        pretrain_ckpt: Path or URL used when ``ckpt_type='pretrain'``.
        ckpt: Path or URL used for every other ``ckpt_type``.
        ckpt_type: ``'pretrain'``, ``'finetune'`` or ``'resume'``; any other
            value loads the weights into the whole model with a plain
            ``set_dict`` call.
        exclude_params: Keys removed from the loaded parameter dict before it
            is applied. The default list is never mutated.
        open_debug: When true, print the keys of the loaded parameter dict.

    Returns:
        ``model``. It is returned untouched when no checkpoint is given or
        the resolved path does not exist.
    """
    if ckpt_type == 'pretrain':
        ckpt = pretrain_ckpt

    # Check for "no checkpoint" before resolving the path: get_ckpt_path
    # calls .startswith on its argument and would fail on None.
    if ckpt is None:
        return model
    ckpt = get_ckpt_path(ckpt)
    if ckpt is None or not os.path.exists(ckpt):
        return model

    param_state_dict, optim_state_dict = fluid.load_dygraph(ckpt)
    if open_debug:
        print("Loading Weights: ", param_state_dict.keys())

    # `or ()` tolerates an explicit exclude_params=None.
    for k in exclude_params or ():
        param_state_dict.pop(k, None)

    if ckpt_type == 'pretrain':
        model.backbone.set_dict(param_state_dict)
    elif ckpt_type == 'finetune':
        model.set_dict(param_state_dict, use_structured_name=True)
    else:
        model.set_dict(param_state_dict)

    if ckpt_type == 'resume':
        if optim_state_dict is None:
            print("Can't Resume Last Training's Optimizer State!!!")
        else:
            optimizer.set_dict(optim_state_dict)
    return model


def save_dygraph_ckpt(model, optimizer, save_dir):
    """Save the model and optimizer state dicts with ``save_dygraph``.

    Args:
        model: Object exposing ``state_dict()`` with the parameters to save.
        optimizer: Object exposing ``state_dict()`` with the optimizer state.
        save_dir: Path created as a directory if missing and then passed to
            ``fluid.dygraph.save_dygraph`` for both state dicts.
            NOTE(review): save_dygraph presumably treats this as a file
            prefix and appends its own suffixes - confirm against the
            Paddle version in use.

    Raises:
        OSError: If ``save_dir`` does not exist and cannot be created.
    """
    # Attempt the creation and tolerate "already exists" rather than checking
    # first: several trainers may reach this point at the same time.
    try:
        os.makedirs(save_dir)
    except OSError:
        if not os.path.exists(save_dir):
            raise
    fluid.dygraph.save_dygraph(model.state_dict(), save_dir)
    fluid.dygraph.save_dygraph(optimizer.state_dict(), save_dir)
    print("Save checkpoint:", save_dir)