From c3d0e3fb8af5b40b97b1eeb8c6f3dee39b5ec206 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Mon, 9 Nov 2020 17:33:49 +0800 Subject: [PATCH] check md5 before load weights. (#1663) * check md5 before load weights. --- ppdet/utils/download.py | 51 ++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/ppdet/utils/download.py b/ppdet/utils/download.py index 2ddb8be15..58f7be307 100644 --- a/ppdet/utils/download.py +++ b/ppdet/utils/download.py @@ -124,8 +124,8 @@ def get_dataset_path(path, annotation, image_dir): "Please apply and download the dataset from " "https://www.objects365.org/download.html".format(name)) data_dir = osp.join(DATASET_HOME, name) - # For voc, only check dir VOCdevkit/VOC2012, VOCdevkit/VOC2007 - if name == 'voc' or name == 'fruit' or name == 'roadsign_voc': + # For VOC-style datasets, only check subdirs + if name in ['voc', 'fruit', 'roadsign_voc']: exists = True for sub_dir in dataset[1]: check_dir = osp.join(data_dir, sub_dir) @@ -203,20 +203,26 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): if fullpath.find(k) >= 0: fullpath = osp.join(osp.split(fullpath)[0], v) - exist_flag = False if osp.exists(fullpath) and check_exist: - exist_flag = True - logger.debug("Found {}".format(fullpath)) - else: - exist_flag = False - fullname = _download(url, root_dir, md5sum) + # If fullpath is a directory, it has been decompressed + # checking MD5 is impossible, so we skip checking when + # fullpath is a directory here + if osp.isdir(fullpath) or \ + _md5check_from_req(fullpath, + requests.get(url, stream=True)): + logger.debug("Found {}".format(fullpath)) + return fullpath, True + else: + shutil.rmtree(fullpath) - # new weights format which postfix is 'pdparams' not - # need to decompress - if osp.splitext(fullname)[-1] != '.pdparams': - _decompress(fullname) + fullname = _download(url, root_dir, md5sum) - return fullpath, exist_flag + # new weights format whose postfix is 'pdparams', + # 
which does not need to be decompressed + if osp.splitext(fullname)[-1] != '.pdparams': + _decompress(fullname) + + return fullpath, False def download_dataset(path, dataset=None): @@ -308,11 +314,7 @@ def _download(url, path, md5sum=None): f.write(chunk) # check md5 after download in Content-MD5 in req.headers - content_md5 = req.headers.get('content-md5') - if not content_md5 or _md5check( - tmp_fullname, - binascii.hexlify(base64.b64decode(content_md5.strip( - '"'))).decode()): + if _md5check_from_req(tmp_fullname, req): shutil.move(tmp_fullname, fullname) return fullname else: @@ -322,6 +324,19 @@ def _download(url, path, md5sum=None): continue +def _md5check_from_req(weights_path, req): + # For weights in bcebos URLs, MD5 value is contained + # in request header as 'content-md5' + content_md5 = req.headers.get('content-md5') + if not content_md5 or _md5check( + weights_path, + binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode( + )): + return True + else: + return False + + def _md5check(fullname, md5sum=None): if md5sum is None: return True -- GitLab