diff --git a/ppdet/utils/download.py b/ppdet/utils/download.py index 2ddb8be1564405f79cdecbd1d67b01040e6d3d81..58f7be307733b26c6a366ad8f9f54adfb75926e1 100644 --- a/ppdet/utils/download.py +++ b/ppdet/utils/download.py @@ -124,8 +124,8 @@ def get_dataset_path(path, annotation, image_dir): "Please apply and download the dataset from " "https://www.objects365.org/download.html".format(name)) data_dir = osp.join(DATASET_HOME, name) - # For voc, only check dir VOCdevkit/VOC2012, VOCdevkit/VOC2007 - if name == 'voc' or name == 'fruit' or name == 'roadsign_voc': + # For VOC-style datasets, only check subdirs + if name in ['voc', 'fruit', 'roadsign_voc']: exists = True for sub_dir in dataset[1]: check_dir = osp.join(data_dir, sub_dir) @@ -203,20 +203,26 @@ def get_path(url, root_dir, md5sum=None, check_exist=True): if fullpath.find(k) >= 0: fullpath = osp.join(osp.split(fullpath)[0], v) - exist_flag = False if osp.exists(fullpath) and check_exist: - exist_flag = True - logger.debug("Found {}".format(fullpath)) - else: - exist_flag = False - fullname = _download(url, root_dir, md5sum) + # If fullpath is a directory, it has been decompressed and + # checking MD5 is impossible, so we skip checking when + # fullpath is a directory here + if osp.isdir(fullpath) or \ + _md5check_from_req(fullpath, + requests.get(url, stream=True)): + logger.debug("Found {}".format(fullpath)) + return fullpath, True + else: + shutil.rmtree(fullpath) - # new weights format which postfix is 'pdparams' not - # need to decompress - if osp.splitext(fullname)[-1] != '.pdparams': - _decompress(fullname) + fullname = _download(url, root_dir, md5sum) - return fullpath, exist_flag + # new weights format whose postfix is 'pdparams', + # which does not need to be decompressed + if osp.splitext(fullname)[-1] != '.pdparams': + _decompress(fullname) + + return fullpath, False def download_dataset(path, dataset=None): @@ -308,11 +314,7 @@ def _download(url, path, md5sum=None): f.write(chunk) # check md5 after download in 
Content-MD5 in req.headers - content_md5 = req.headers.get('content-md5') - if not content_md5 or _md5check( - tmp_fullname, - binascii.hexlify(base64.b64decode(content_md5.strip( - '"'))).decode()): + if _md5check_from_req(tmp_fullname, req): shutil.move(tmp_fullname, fullname) return fullname else: @@ -322,6 +324,19 @@ def _download(url, path, md5sum=None): continue +def _md5check_from_req(weights_path, req): + # For weights in bcebos URLs, the MD5 value is contained + # in the response header as 'content-md5' + content_md5 = req.headers.get('content-md5') + if not content_md5 or _md5check( + weights_path, + binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode( + )): + return True + else: + return False + + def _md5check(fullname, md5sum=None): if md5sum is None: return True