未验证 提交 f24c1b05 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix re-download not trigger when weights update (#2617)

上级 f161d2a1
...@@ -23,6 +23,8 @@ import shutil ...@@ -23,6 +23,8 @@ import shutil
import requests import requests
import tqdm import tqdm
import hashlib import hashlib
import base64
import binascii
import tarfile import tarfile
import zipfile import zipfile
...@@ -257,20 +259,21 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): ...@@ -257,20 +259,21 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
if fullpath.find(k) >= 0: if fullpath.find(k) >= 0:
fullpath = osp.join(osp.split(fullpath)[0], v) fullpath = osp.join(osp.split(fullpath)[0], v)
exist_flag = False
if osp.exists(fullpath) and check_exist: if osp.exists(fullpath) and check_exist:
exist_flag = True if _check_exist_file_md5(fullpath, md5sum, url):
logger.debug("Found {}".format(fullpath)) logger.debug("Found {}".format(fullpath))
else: return fullpath, True
exist_flag = False else:
fullname = _download(url, root_dir, md5sum) os.remove(fullpath)
# new weights format which postfix is 'pdparams' not fullname = _download(url, root_dir, md5sum)
# need to decompress
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
_decompress(fullname)
return fullpath, exist_flag # new weights format which postfix is 'pdparams' not
# need to decompress
if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']:
_decompress(fullname)
return fullpath, False
def download_dataset(path, dataset=None): def download_dataset(path, dataset=None):
...@@ -324,7 +327,8 @@ def _download(url, path, md5sum=None): ...@@ -324,7 +327,8 @@ def _download(url, path, md5sum=None):
fullname = osp.join(path, fname) fullname = osp.join(path, fname)
retry_cnt = 0 retry_cnt = 0
while not (osp.exists(fullname) and _md5check(fullname, md5sum)): while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum,
url)):
if retry_cnt < DOWNLOAD_RETRY_LIMIT: if retry_cnt < DOWNLOAD_RETRY_LIMIT:
retry_cnt += 1 retry_cnt += 1
else: else:
...@@ -355,8 +359,30 @@ def _download(url, path, md5sum=None): ...@@ -355,8 +359,30 @@ def _download(url, path, md5sum=None):
if chunk: if chunk:
f.write(chunk) f.write(chunk)
shutil.move(tmp_fullname, fullname) shutil.move(tmp_fullname, fullname)
return fullname
return fullname
def _check_exist_file_md5(filename, md5sum, url):
# if md5sum is None, and file to check is weights file,
# read md5um from url and check, else check md5sum directly
return _md5check_from_url(filename, url) if md5sum is None \
and filename.endswith('pdparams') \
else _md5check(filename, md5sum)
def _md5check_from_url(filename, url):
# For weights in bcebos URLs, MD5 value is contained
# in request header as 'content_md5'
req = requests.get(url, stream=True)
content_md5 = req.headers.get('content-md5')
req.close()
if not content_md5 or _md5check(
filename,
binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode(
)):
return True
else:
return False
def _md5check(fullname, md5sum=None): def _md5check(fullname, md5sum=None):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册