未验证 提交 f24c1b05 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix re-download not trigger when weights update (#2617)

上级 f161d2a1
......@@ -23,6 +23,8 @@ import shutil
import requests
import tqdm
import hashlib
import base64
import binascii
import tarfile
import zipfile
......@@ -257,20 +259,21 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
if fullpath.find(k) >= 0:
fullpath = osp.join(osp.split(fullpath)[0], v)
exist_flag = False
if osp.exists(fullpath) and check_exist:
exist_flag = True
logger.debug("Found {}".format(fullpath))
else:
exist_flag = False
fullname = _download(url, root_dir, md5sum)
if _check_exist_file_md5(fullpath, md5sum, url):
logger.debug("Found {}".format(fullpath))
return fullpath, True
else:
os.remove(fullpath)
# new weights format which postfix is 'pdparams' not
# need to decompress
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
_decompress(fullname)
fullname = _download(url, root_dir, md5sum)
return fullpath, exist_flag
# new weights format which postfix is 'pdparams' not
# need to decompress
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
_decompress(fullname)
return fullpath, False
def download_dataset(path, dataset=None):
......@@ -324,7 +327,8 @@ def _download(url, path, md5sum=None):
fullname = osp.join(path, fname)
retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
url)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1
else:
......@@ -355,8 +359,30 @@ def _download(url, path, md5sum=None):
if chunk:
f.write(chunk)
shutil.move(tmp_fullname, fullname)
return fullname
return fullname
def _check_exist_file_md5(filename, md5sum, url):
# if md5sum is None, and file to check is weights file,
# read md5um from url and check, else check md5sum directly
return _md5check_from_url(filename, url) if md5sum is None \
and filename.endswith('pdparams') \
else _md5check(filename, md5sum)
def _md5check_from_url(filename, url):
# For weights in bcebos URLs, MD5 value is contained
# in request header as 'content_md5'
req = requests.get(url, stream=True)
content_md5 = req.headers.get('content-md5')
req.close()
if not content_md5 or _md5check(
filename,
binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode(
)):
return True
else:
return False
def _md5check(fullname, md5sum=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册