Unverified commit 73bbc913, authored by Kaipeng Deng, committed by GitHub

unify dist download (#3867)

Parent b63fe624
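The change removes the dist-aware wrapper `get_weights_path_dist` from the checkpoint utilities and moves the multi-trainer coordination into the download utilities as `_download_dist` and `_decompress_dist`, so `load_weight` and `load_pretrain_weight` now call the plain `get_weights_path` while download and decompression are deduplicated across trainers in one place. Both new helpers use the same lock-file protocol: one unique endpoint per machine does the work and then removes the lock, and the remaining trainers poll until the lock file disappears. A minimal standalone sketch of that protocol (the helper name `run_once_per_node` and its arguments are illustrative, not part of this commit):

```python
import os
import time


def run_once_per_node(target, is_unique_endpoint, do_work):
    # Illustrative sketch of the lock-file protocol used by _download_dist and
    # _decompress_dist: `is_unique_endpoint` stands in for the
    # ParallelEnv().current_endpoint membership check, `do_work` for the
    # actual _download / _decompress call.
    lock_path = target + '.lock'
    with open(lock_path, 'w'):  # touch the lock so waiting trainers can poll it
        os.utime(lock_path, None)
    if is_unique_endpoint:
        do_work()              # only one trainer per machine does the work
        os.remove(lock_path)   # removing the lock releases the waiters
    else:
        while os.path.exists(lock_path):
            time.sleep(1)      # wait for the designated trainer to finish
```

In the commit itself this protocol is wrapped in the existing PADDLE_TRAINERS_NUM / PADDLE_TRAINER_ID environment checks, so single-trainer runs fall straight through to the plain `_download` / `_decompress` path.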
ppdet/utils/checkpoint.py
@@ -55,41 +55,6 @@ def _get_unique_endpoints(trainer_endpoints):
     return unique_endpoints
-def get_weights_path_dist(path):
-    env = os.environ
-    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
-        trainer_id = int(env['PADDLE_TRAINER_ID'])
-        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
-        if num_trainers <= 1:
-            path = get_weights_path(path)
-        else:
-            from ppdet.utils.download import map_path, WEIGHTS_HOME
-            weight_path = map_path(path, WEIGHTS_HOME)
-            lock_path = weight_path + '.lock'
-            if not os.path.exists(weight_path):
-                from paddle.distributed import ParallelEnv
-                unique_endpoints = _get_unique_endpoints(ParallelEnv()
-                                                         .trainer_endpoints[:])
-                try:
-                    os.makedirs(os.path.dirname(weight_path))
-                except OSError as e:
-                    if e.errno != errno.EEXIST:
-                        raise
-                with open(lock_path, 'w'):  # touch
-                    os.utime(lock_path, None)
-                if ParallelEnv().current_endpoint in unique_endpoints:
-                    get_weights_path(path)
-                    os.remove(lock_path)
-                else:
-                    while os.path.exists(lock_path):
-                        time.sleep(1)
-            path = weight_path
-    else:
-        path = get_weights_path(path)
-    return path
 def _strip_postfix(path):
     path, ext = os.path.splitext(path)
     assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
@@ -99,7 +64,7 @@ def _strip_postfix(path):
 def load_weight(model, weight, optimizer=None):
     if is_url(weight):
-        weight = get_weights_path_dist(weight)
+        weight = get_weights_path(weight)
     path = _strip_postfix(weight)
     pdparam_path = path + '.pdparams'
@@ -205,7 +170,7 @@ def match_state_dict(model_state_dict, weight_state_dict):
 def load_pretrain_weight(model, pretrain_weight):
     if is_url(pretrain_weight):
-        pretrain_weight = get_weights_path_dist(pretrain_weight)
+        pretrain_weight = get_weights_path(pretrain_weight)
     path = _strip_postfix(pretrain_weight)
     if not (os.path.isdir(path) or os.path.isfile(path) or
@@ -251,4 +216,4 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
     state_dict = optimizer.state_dict()
     state_dict['last_epoch'] = last_epoch
     paddle.save(state_dict, save_path + ".pdopt")
-    logger.info("Save checkpoint: {}".format(save_dir))
\ No newline at end of file
+    logger.info("Save checkpoint: {}".format(save_dir))
ppdet/utils/download.py
@@ -20,6 +20,7 @@ import os
 import os.path as osp
 import sys
 import yaml
+import time
 import shutil
 import requests
 import tqdm
@@ -29,6 +30,7 @@ import binascii
 import tarfile
 import zipfile
+from paddle.utils.download import _get_unique_endpoints
 from ppdet.core.workspace import BASE_KEY
 from .logger import setup_logger
 from .voc_utils import create_list
@@ -144,8 +146,8 @@ def get_config_path(url):
     cfg_url = parse_url(cfg_url)
     # 3. download and decompress
-    cfg_fullname = _download(cfg_url, osp.dirname(CONFIGS_HOME))
-    _decompress(cfg_fullname)
+    cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME))
+    _decompress_dist(cfg_fullname)
     # 4. check config file existing
     if os.path.isfile(path):
@@ -281,12 +283,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
         else:
             os.remove(fullpath)
-    fullname = _download(url, root_dir, md5sum)
+    fullname = _download_dist(url, root_dir, md5sum)
     # new weights format which postfix is 'pdparams' not
     # need to decompress
     if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
-        _decompress(fullname)
+        _decompress_dist(fullname)
     return fullpath, False
@@ -381,6 +383,38 @@ def _download(url, path, md5sum=None):
     return fullname
+def _download_dist(url, path, md5sum=None):
+    env = os.environ
+    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
+        trainer_id = int(env['PADDLE_TRAINER_ID'])
+        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
+        if num_trainers <= 1:
+            return _download(url, path, md5sum)
+        else:
+            fname = osp.split(url)[-1]
+            fullname = osp.join(path, fname)
+            lock_path = fullname + '.download.lock'
+            if not osp.isdir(path):
+                os.makedirs(path)
+            if not osp.exists(fullname):
+                from paddle.distributed import ParallelEnv
+                unique_endpoints = _get_unique_endpoints(ParallelEnv()
+                                                         .trainer_endpoints[:])
+                with open(lock_path, 'w'):  # touch
+                    os.utime(lock_path, None)
+                if ParallelEnv().current_endpoint in unique_endpoints:
+                    _download(url, path, md5sum)
+                    os.remove(lock_path)
+                else:
+                    while os.path.exists(lock_path):
+                        time.sleep(1)
+            return fullname
+    else:
+        return _download(url, path, md5sum)
 def _check_exist_file_md5(filename, md5sum, url):
     # if md5sum is None, and file to check is weights file,
     # read md5um from url and check, else check md5sum directly
@@ -458,6 +492,30 @@ def _decompress(fname):
     os.remove(fname)
+def _decompress_dist(fname):
+    env = os.environ
+    if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
+        trainer_id = int(env['PADDLE_TRAINER_ID'])
+        num_trainers = int(env['PADDLE_TRAINERS_NUM'])
+        if num_trainers <= 1:
+            _decompress(fname)
+        else:
+            lock_path = fname + '.decompress.lock'
+            from paddle.distributed import ParallelEnv
+            unique_endpoints = _get_unique_endpoints(ParallelEnv()
+                                                     .trainer_endpoints[:])
+            with open(lock_path, 'w'):  # touch
+                os.utime(lock_path, None)
+            if ParallelEnv().current_endpoint in unique_endpoints:
+                _decompress(fname)
+                os.remove(lock_path)
+            else:
+                while os.path.exists(lock_path):
+                    time.sleep(1)
+    else:
+        _decompress(fname)
 def _move_and_merge_tree(src, dst):
     """
     Move src directory to dst, if dst is already exists,