未验证 提交 73bbc913 编写于 作者: K Kaipeng Deng 提交者: GitHub

unify dist download (#3867)

上级 b63fe624
...@@ -55,41 +55,6 @@ def _get_unique_endpoints(trainer_endpoints): ...@@ -55,41 +55,6 @@ def _get_unique_endpoints(trainer_endpoints):
return unique_endpoints return unique_endpoints
def get_weights_path_dist(path):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
path = get_weights_path(path)
else:
from ppdet.utils.download import map_path, WEIGHTS_HOME
weight_path = map_path(path, WEIGHTS_HOME)
lock_path = weight_path + '.lock'
if not os.path.exists(weight_path):
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
try:
os.makedirs(os.path.dirname(weight_path))
except OSError as e:
if e.errno != errno.EEXIST:
raise
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
get_weights_path(path)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
path = weight_path
else:
path = get_weights_path(path)
return path
def _strip_postfix(path): def _strip_postfix(path):
path, ext = os.path.splitext(path) path, ext = os.path.splitext(path)
assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \ assert ext in ['', '.pdparams', '.pdopt', '.pdmodel'], \
...@@ -99,7 +64,7 @@ def _strip_postfix(path): ...@@ -99,7 +64,7 @@ def _strip_postfix(path):
def load_weight(model, weight, optimizer=None): def load_weight(model, weight, optimizer=None):
if is_url(weight): if is_url(weight):
weight = get_weights_path_dist(weight) weight = get_weights_path(weight)
path = _strip_postfix(weight) path = _strip_postfix(weight)
pdparam_path = path + '.pdparams' pdparam_path = path + '.pdparams'
...@@ -205,7 +170,7 @@ def match_state_dict(model_state_dict, weight_state_dict): ...@@ -205,7 +170,7 @@ def match_state_dict(model_state_dict, weight_state_dict):
def load_pretrain_weight(model, pretrain_weight): def load_pretrain_weight(model, pretrain_weight):
if is_url(pretrain_weight): if is_url(pretrain_weight):
pretrain_weight = get_weights_path_dist(pretrain_weight) pretrain_weight = get_weights_path(pretrain_weight)
path = _strip_postfix(pretrain_weight) path = _strip_postfix(pretrain_weight)
if not (os.path.isdir(path) or os.path.isfile(path) or if not (os.path.isdir(path) or os.path.isfile(path) or
...@@ -251,4 +216,4 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch): ...@@ -251,4 +216,4 @@ def save_model(model, optimizer, save_dir, save_name, last_epoch):
state_dict = optimizer.state_dict() state_dict = optimizer.state_dict()
state_dict['last_epoch'] = last_epoch state_dict['last_epoch'] = last_epoch
paddle.save(state_dict, save_path + ".pdopt") paddle.save(state_dict, save_path + ".pdopt")
logger.info("Save checkpoint: {}".format(save_dir)) logger.info("Save checkpoint: {}".format(save_dir))
\ No newline at end of file
...@@ -20,6 +20,7 @@ import os ...@@ -20,6 +20,7 @@ import os
import os.path as osp import os.path as osp
import sys import sys
import yaml import yaml
import time
import shutil import shutil
import requests import requests
import tqdm import tqdm
...@@ -29,6 +30,7 @@ import binascii ...@@ -29,6 +30,7 @@ import binascii
import tarfile import tarfile
import zipfile import zipfile
from paddle.utils.download import _get_unique_endpoints
from ppdet.core.workspace import BASE_KEY from ppdet.core.workspace import BASE_KEY
from .logger import setup_logger from .logger import setup_logger
from .voc_utils import create_list from .voc_utils import create_list
...@@ -144,8 +146,8 @@ def get_config_path(url): ...@@ -144,8 +146,8 @@ def get_config_path(url):
cfg_url = parse_url(cfg_url) cfg_url = parse_url(cfg_url)
# 3. download and decompress # 3. download and decompress
cfg_fullname = _download(cfg_url, osp.dirname(CONFIGS_HOME)) cfg_fullname = _download_dist(cfg_url, osp.dirname(CONFIGS_HOME))
_decompress(cfg_fullname) _decompress_dist(cfg_fullname)
# 4. check config file existing # 4. check config file existing
if os.path.isfile(path): if os.path.isfile(path):
...@@ -281,12 +283,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): ...@@ -281,12 +283,12 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
else: else:
os.remove(fullpath) os.remove(fullpath)
fullname = _download(url, root_dir, md5sum) fullname = _download_dist(url, root_dir, md5sum)
# new weights format which postfix is 'pdparams' not # new weights format which postfix is 'pdparams' not
# need to decompress # need to decompress
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']: if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
_decompress(fullname) _decompress_dist(fullname)
return fullpath, False return fullpath, False
...@@ -381,6 +383,38 @@ def _download(url, path, md5sum=None): ...@@ -381,6 +383,38 @@ def _download(url, path, md5sum=None):
return fullname return fullname
def _download_dist(url, path, md5sum=None):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
return _download(url, path, md5sum)
else:
fname = osp.split(url)[-1]
fullname = osp.join(path, fname)
lock_path = fullname + '.download.lock'
if not osp.isdir(path):
os.makedirs(path)
if not osp.exists(fullname):
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
_download(url, path, md5sum)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
return fullname
else:
return _download(url, path, md5sum)
def _check_exist_file_md5(filename, md5sum, url): def _check_exist_file_md5(filename, md5sum, url):
# if md5sum is None, and file to check is weights file, # if md5sum is None, and file to check is weights file,
# read md5um from url and check, else check md5sum directly # read md5um from url and check, else check md5sum directly
...@@ -458,6 +492,30 @@ def _decompress(fname): ...@@ -458,6 +492,30 @@ def _decompress(fname):
os.remove(fname) os.remove(fname)
def _decompress_dist(fname):
env = os.environ
if 'PADDLE_TRAINERS_NUM' in env and 'PADDLE_TRAINER_ID' in env:
trainer_id = int(env['PADDLE_TRAINER_ID'])
num_trainers = int(env['PADDLE_TRAINERS_NUM'])
if num_trainers <= 1:
_decompress(fname)
else:
lock_path = fname + '.decompress.lock'
from paddle.distributed import ParallelEnv
unique_endpoints = _get_unique_endpoints(ParallelEnv()
.trainer_endpoints[:])
with open(lock_path, 'w'): # touch
os.utime(lock_path, None)
if ParallelEnv().current_endpoint in unique_endpoints:
_decompress(fname)
os.remove(lock_path)
else:
while os.path.exists(lock_path):
time.sleep(1)
else:
_decompress(fname)
def _move_and_merge_tree(src, dst): def _move_and_merge_tree(src, dst):
""" """
Move src directory to dst, if dst is already exists, Move src directory to dst, if dst is already exists,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册