diff --git a/ppdet/utils/checkpoint.py b/ppdet/utils/checkpoint.py index 329fa2194e6c54278c205ab2e6eaebfa011b4472..28b4608ac7776aa5fe9e7348b267541c49b11e94 100644 --- a/ppdet/utils/checkpoint.py +++ b/ppdet/utils/checkpoint.py @@ -55,41 +55,6 @@ def _get_unique_endpoints(trainer_endpoints): return unique_endpoints -def get_weights_path_dist(path): - env = os.environ - if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: - trainer_id = int(env['PADDLE_TRAINER_ID']) - num_trainers = int(env['PADDLE_TRAINERS_NUM']) - if num_trainers <= 1: - path = get_weights_path(path) - else: - from ppdet.utils.download import map_path, WEIGHTS_HOME - weight_path = map_path(path, WEIGHTS_HOME) - lock_path = weight_path + '.lock' - if not os.path.exists(weight_path): - from paddle.distributed import ParallelEnv - unique_endpoints = _get_unique_endpoints(ParallelEnv() - .trainer_endpoints[:]) - try: - os.makedirs(os.path.dirname(weight_path)) - except OSError as e: - if e.errno != errno.EEXIST: - raise - with open(lock_path, 'w'): # touch - os.utime(lock_path, None) - if ParallelEnv().current_endpoint in unique_endpoints: - get_weights_path(path) - os.remove(lock_path) - else: - while os.path.exists(lock_path): - time.sleep(1) - path = weight_path - else: - path = get_weights_path(path) - - return path - - def _strip_postfix(path): path, ext = os.path.splitext(path) assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ @@ -99,7 +64,7 @@ def _strip_postfix(path): def load_weight(model, weight, optimizer=None): if is_url(weight): - weight = get_weights_path_dist(weight) + weight = get_weights_path(weight) path = _strip_postfix(weight) pdparam_path = path + '.pdparams' @@ -205,7 +170,7 @@ def match_state_dict(model_state_dict, weight_state_dict): def load_pretrain_weight(model, pretrain_weight): if is_url(pretrain_weight): - pretrain_weight = get_weights_path_dist(pretrain_weight) + pretrain_weight = get_weights_path(pretrain_weight) path = _strip_postfix(pretrain_weight) if not (os.path.isdir(path) or os.path.isfile(path) or @@ -251,4 +216,4 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch): state_dict = optimizer.state_dict() state_dict['last_epoch'] = last_epoch paddle.save(state_dict, save_path + ".pdopt") - logger.info("Save checkpoint: {}".format(save_dir)) \ No newline at end of file + logger.info("Save checkpoint: {}".format(save_dir)) diff --git a/ppdet/utils/download.py b/ppdet/utils/download.py index a531732052ac2b2bd08e755139d15403e49a386a..4c4c27c9d2678c8222dd06515736a8b186c5135c 100644 --- a/ppdet/utils/download.py +++ b/ppdet/utils/download.py @@ -20,6 +20,7 @@ import os import os.path as osp import sys import yaml +import time import shutil import requests import tqdm @@ -29,6 +30,7 @@ import binascii import tarfile import zipfile +from paddle.utils.download import _get_unique_endpoints from ppdet.core.workspace import BASE_KEY from .logger import setup_logger from .voc_utils import create_list @@ -147,8 +149,8 @@ def get_config_path(url): cfg_url = parse_url(cfg_url) # 3. download and decompress - cfg_fullname = _download(cfg_url, osp.dirname(CONFIGS_HOME)) - _decompress(cfg_fullname) + cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME)) + _decompress_dist(cfg_fullname) # 4. check config file existing if os.path.isfile(path): @@ -284,12 +286,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): else: os.remove(fullpath) - fullname = _download(url, root_dir, md5sum) + fullname = _download_dist(url, root_dir, md5sum) # new weights format which postfix is 'pdparams' not # need to decompress if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']: - _decompress(fullname) + _decompress_dist(fullname) return fullpath, False @@ -384,6 +386,38 @@ def _download(url, path, md5sum=None): return fullname +def _download_dist(url, path, md5sum=None): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: + return _download(url, path, md5sum) + else: + fname = osp.split(url)[-1] + fullname = osp.join(path, fname) + lock_path = fullname + '.download.lock' + + if not osp.isdir(path): + os.makedirs(path) + + if not osp.exists(fullname): + from paddle.distributed import ParallelEnv + unique_endpoints = _get_unique_endpoints(ParallelEnv() + .trainer_endpoints[:]) + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if ParallelEnv().current_endpoint in unique_endpoints: + _download(url, path, md5sum) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + return fullname + else: + return _download(url, path, md5sum) + + def _check_exist_file_md5(filename, md5sum, url): # if md5sum is None, and file to check is weights file, # read md5um from url and check, else check md5sum directly @@ -461,6 +495,30 @@ def _decompress(fname): os.remove(fname) +def _decompress_dist(fname): + env = os.environ + if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env: + trainer_id = int(env['PADDLE_TRAINER_ID']) + num_trainers = int(env['PADDLE_TRAINERS_NUM']) + if num_trainers <= 1: + _decompress(fname) + else: + lock_path = fname + '.decompress.lock' + from paddle.distributed import ParallelEnv + unique_endpoints = _get_unique_endpoints(ParallelEnv() + .trainer_endpoints[:]) + with open(lock_path, 'w'): # touch + os.utime(lock_path, None) + if ParallelEnv().current_endpoint in unique_endpoints: + _decompress(fname) + os.remove(lock_path) + else: + while os.path.exists(lock_path): + time.sleep(1) + else: + _decompress(fname) + + def _move_and_merge_tree(src, dst): """ Move src directory to dst, if dst is already exists,