未验证 提交 c3d0e3fb 编写于 作者: K Kaipeng Deng 提交者: GitHub

check md5 before load weights. (#1663)

* check md5 before load weights.
上级 6f15306d
......@@ -124,8 +124,8 @@ def get_dataset_path(path, annotation, image_dir):
"Please apply and download the dataset from "
"https://www.objects365.org/download.html".format(name))
data_dir = osp.join(DATASET_HOME, name)
# For voc, only check dir VOCdevkit/VOC2012, VOCdevkit/VOC2007
if name == 'voc' or name == 'fruit' or name == 'roadsign_voc':
# For VOC-style datasets, only check subdirs
if name in ['voc', 'fruit', 'roadsign_voc']:
exists = True
for sub_dir in dataset[1]:
check_dir = osp.join(data_dir, sub_dir)
......@@ -203,20 +203,26 @@ def get_path(url, root_dir, md5sum=None, check_exist=True):
if fullpath.find(k) >= 0:
fullpath = osp.join(osp.split(fullpath)[0], v)
exist_flag = False
if osp.exists(fullpath) and check_exist:
exist_flag = True
logger.debug("Found {}".format(fullpath))
else:
exist_flag = False
fullname = _download(url, root_dir, md5sum)
# If fullpath is a directory, it has been decompressed
# checking MD5 is impossible, so we skip checking when
# fullpath is a directory here
if osp.isdir(fullpath) or \
_md5check_from_req(fullpath,
requests.get(url, stream=True)):
logger.debug("Found {}".format(fullpath))
return fullpath, True
else:
shutil.rmtree(fullpath)
# new weights format which postfix is 'pdparams' not
# need to decompress
if osp.splitext(fullname)[-1] != '.pdparams':
_decompress(fullname)
fullname = _download(url, root_dir, md5sum)
return fullpath, exist_flag
# new weights format whose postfix is 'pdparams',
# which is not need to decompress
if osp.splitext(fullname)[-1] != '.pdparams':
_decompress(fullname)
return fullpath, False
def download_dataset(path, dataset=None):
......@@ -308,11 +314,7 @@ def _download(url, path, md5sum=None):
f.write(chunk)
# check md5 after download in Content-MD5 in req.headers
content_md5 = req.headers.get('content-md5')
if not content_md5 or _md5check(
tmp_fullname,
binascii.hexlify(base64.b64decode(content_md5.strip(
'"'))).decode()):
if _md5check_from_req(tmp_fullname, req):
shutil.move(tmp_fullname, fullname)
return fullname
else:
......@@ -322,6 +324,19 @@ def _download(url, path, md5sum=None):
continue
def _md5check_from_req(weights_path, req):
# For weights in bcebos URLs, MD5 value is contained
# in request header as 'content_md5'
content_md5 = req.headers.get('content-md5')
if not content_md5 or _md5check(
weights_path,
binascii.hexlify(base64.b64decode(content_md5.strip('"'))).decode(
)):
return True
else:
return False
def _md5check(fullname, md5sum=None):
if md5sum is None:
return True
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册