diff --git a/ppdet/utils/download.py b/ppdet/utils/download.py index 3b50ddd010d1263b30199350418f2d02c085f497..9e983efa6465eb0adf3fd0cd725a3f0495933111 100644 --- a/ppdet/utils/download.py +++ b/ppdet/utils/download.py @@ -23,6 +23,8 @@ import shutil import requests import tqdm import hashlib +import base64 +import binascii import tarfile import zipfile @@ -257,20 +259,21 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): if fullpath.find(k) >= 0: fullpath = osp.join(osp.split(fullpath)[0], v) - exist_flag = False if osp.exists(fullpath) and check_exist: - exist_flag = True - logger.debug("Found {}".format(fullpath)) - else: - exist_flag = False - fullname = _download(url, root_dir, md5sum) + if _check_exist_file_md5(fullpath, md5sum, url): + logger.debug("Found {}".format(fullpath)) + return fullpath, True + else: + os.remove(fullpath) - # new weights format which postfix is 'pdparams' not - # need to decompress - if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']: - _decompress(fullname) + fullname = _download(url, root_dir, md5sum) - return fullpath, exist_flag + # new weights format which postfix is 'pdparams' not + # need to decompress + if osp.splitext(fullname)[-1] not in ['.pdparams', '.yml']: + _decompress(fullname) + + return fullpath, False def download_dataset(path, dataset=None): @@ -324,7 +327,8 @@ def _download(url, path, md5sum=None): fullname = osp.join(path, fname) retry_cnt = 0 - while not (osp.exists(fullname) and _md5check(fullname, md5sum)): + while not (osp.exists(fullname) and _check_exist_file_md5(fullname, md5sum, + url)): if retry_cnt < DOWNLOAD_RETRY_LIMIT: retry_cnt += 1 else: @@ -355,8 +359,30 @@ def _download(url, path, md5sum=None): if chunk: f.write(chunk) shutil.move(tmp_fullname, fullname) - - return fullname + return fullname + + +def _check_exist_file_md5(filename, md5sum, url): + # if md5sum is None, and file to check is weights file, + # read md5um from url and check, else check md5sum directly + return _md5check_from_url(filename, url) if md5sum is None \ + and filename.endswith('pdparams') \ + else _md5check(filename, md5sum) + + +def _md5check_from_url(filename, url): + # For weights in bcebos URLs, MD5 value is contained + # in request header as 'content_md5' + req = requests.get(url, stream=True) + content_md5 = req.headers.get('content-md5') + req.close() + if not content_md5 or _md5check( + filename, + binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode( + )): + return True + else: + return False def _md5check(fullname, md5sum=None):