Commit 927d53dc authored by Kaipeng Deng, committed by qingqing01

Support download VOC dataset. (#2564)

* Support download VOC dataset.
Parent e8db86d7
@@ -18,6 +18,7 @@ from __future__ import print_function
 import os
+import os.path as osp
 import shutil
 import requests
 import tqdm
@@ -30,24 +31,28 @@ logger = logging.getLogger(__name__)
 __all__ = ['get_weights_path', 'get_dataset_path']
 
-WEIGHTS_HOME = os.path.expanduser("~/.cache/paddle/weights")
-DATASET_HOME = os.path.expanduser("~/.cache/paddle/dataset")
+WEIGHTS_HOME = osp.expanduser("~/.cache/paddle/weights")
+DATASET_HOME = osp.expanduser("~/.cache/paddle/dataset")
 
+# dict of {dataset_name: (download_info, sub_dirs)}
+# download info: (url, md5sum)
 DATASETS = {
-    'coco': [
-        (
-            'http://images.cocodataset.org/zips/train2017.zip',
-            'cced6f7f71b7629ddf16f17bbcfab6b2', ),
-        (
-            'http://images.cocodataset.org/zips/val2017.zip',
-            '442b8da7639aecaf257c1dceb8ba8c80', ),
-        (
-            'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
-            'f4bbac642086de4f52a3fdda2de5fa2c', ),
-    ],
-    'pascal': [(
-        'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
-        '6cd6e144f989b92b3379bac3b3de84fd', )],
+    'coco': ([
+        ('http://images.cocodataset.org/zips/train2017.zip',
+         'cced6f7f71b7629ddf16f17bbcfab6b2', ),
+        ('http://images.cocodataset.org/zips/val2017.zip',
+         '442b8da7639aecaf257c1dceb8ba8c80', ),
+        ('http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
+         'f4bbac642086de4f52a3fdda2de5fa2c', ),
+    ], ["annotations", "train2017", "val2017"]),
+    'voc': ([
+        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
+         '6cd6e144f989b92b3379bac3b3de84fd', ),
+        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
+         'c52e279531787c972589f7e41ab4ae64', ),
+        ('http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar',
+         'b6e924de25625d8de591ea690078ad9f', ),
+    ], ["VOCdevkit/VOC_all"]),
 }
 
 DOWNLOAD_RETRY_LIMIT = 3
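
The reworked DATASETS schema pairs each dataset name with a two-part tuple: a list of (url, md5sum) download entries and a list of sub-directories expected once everything is extracted. A minimal sketch (not part of the commit) of how one entry is meant to be unpacked; the 'voc' entry below mirrors the dict in the diff:

    DATASETS = {
        'voc': ([
            ('http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
             '6cd6e144f989b92b3379bac3b3de84fd'),
        ], ["VOCdevkit/VOC_all"]),
    }

    download_info, sub_dirs = DATASETS['voc']
    for url, md5sum in download_info:   # each archive to fetch and verify
        print(url, md5sum)
    for sub_dir in sub_dirs:            # dirs that must exist after extraction
        print(sub_dir)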
@@ -66,20 +71,34 @@ def get_dataset_path(path):
     Otherwise, get dataset path from DATASET_HOME, if not exists,
     download it.
     """
-    if os.path.exists(path):
-        logger.debug("Data path: {}".format(os.path.realpath(path)))
+    if _dataset_exists(path):
+        logger.debug("Dataset path: {}".format(osp.realpath(path)))
         return path
 
-    logger.info("DATASET_DIR {} not exitst, try searching {} or "
+    logger.info("Dataset {} does not exist, try searching {} or "
                 "downloading dataset...".format(
-                    os.path.realpath(path), DATASET_HOME))
+                    osp.realpath(path), DATASET_HOME))
 
     for name, dataset in DATASETS.items():
         if path.lower().find(name) >= 0:
-            logger.info("Parse DATASET_DIR {} as dataset "
+            logger.info("Parse dataset_dir {} as dataset "
                         "{}".format(path, name))
-            data_dir = os.path.join(DATASET_HOME, name)
-            for url, md5sum in dataset:
+            data_dir = osp.join(DATASET_HOME, name)
+
+            # For voc, only check merged dir
+            if name == 'voc':
+                check_dir = osp.join(data_dir, dataset[1][0])
+                if osp.exists(check_dir):
+                    logger.info("Found {}".format(check_dir))
+                    return data_dir
+
+            for url, md5sum in dataset[0]:
                 get_path(url, data_dir, md5sum)
+
+            if name == 'voc':
+                logger.info("Download voc dataset succeeded, merge "
+                            "VOC2007 and VOC2012 to VOC_all...")
+                # TODO(dengkaipeng): merge voc
+
             return data_dir
 
     # not match any dataset in DATASETS
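
Typical use of get_dataset_path after this change: if the path already holds the expected sub dirs it is returned as-is, while a path merely containing 'voc' or 'coco' triggers a download into DATASET_HOME. The import path below is an assumption about where this module lives, not taken from the diff:

    # hypothetical import path; adjust to where download.py actually lives
    from ppdet.utils.download import get_dataset_path

    # Returns 'dataset/voc' untouched if VOCdevkit/VOC_all is already
    # there; otherwise downloads into ~/.cache/paddle/dataset/voc.
    data_dir = get_dataset_path('dataset/voc')
    print(data_dir)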
@@ -103,19 +122,19 @@ def get_path(url, root_dir, md5sum=None):
     fpath = fname
     for zip_format in zip_formats:
         fpath = fpath.replace(zip_format, '')
-    fullpath = os.path.join(root_dir, fpath)
+    fullpath = osp.join(root_dir, fpath)
 
     # For same zip file, decompressed directory name different
     # from zip file name, rename by following map
     decompress_name_map = {
-        "VOCtrainval": "VOCdevkit",
+        "VOC": "VOCdevkit/VOC_all",
         "annotations_trainval": "annotations"
     }
     for k, v in decompress_name_map.items():
         if fullpath.find(k) >= 0:
             fullpath = '/'.join(fullpath.split('/')[:-1] + [v])
 
-    if os.path.exists(fullpath):
+    if osp.exists(fullpath):
         logger.info("Found {}".format(fullpath))
     else:
         fullname = _download(url, root_dir, md5sum)
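
The rename map redirects extracted archives whose directory name differs from the zip/tar file name. A self-contained sketch of that loop with an illustrative path (map values copied from decompress_name_map above):

    decompress_name_map = {
        "VOC": "VOCdevkit/VOC_all",
        "annotations_trainval": "annotations",
    }

    fullpath = "/root/dataset/voc/VOCtrainval_11-May-2012"
    for k, v in decompress_name_map.items():
        if fullpath.find(k) >= 0:
            # keep the parent dirs, swap the last component for v
            fullpath = '/'.join(fullpath.split('/')[:-1] + [v])
    print(fullpath)  # /root/dataset/voc/VOCdevkit/VOC_all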
@@ -124,6 +143,22 @@ def get_path(url, root_dir, md5sum=None):
     return fullpath
 
 
+def _dataset_exists(path):
+    """
+    Check if a user-defined dataset exists
+    """
+    if not osp.exists(path):
+        return False
+
+    for name, dataset in DATASETS.items():
+        if path.lower().find(name) >= 0:
+            for sub_dir in dataset[1]:
+                if not osp.exists(osp.join(path, sub_dir)):
+                    return False
+            return True
+    return True
+
+
 def _download(url, path, md5sum=None):
     """
     Download from url, save to path.
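
A quick illustration of _dataset_exists semantics with hypothetical paths (assumes the function above is in scope): a path whose name contains 'voc' only counts as existing when every sub dir listed in dataset[1] is present beneath it.

    import os
    import os.path as osp

    root = "/tmp/demo/voc"
    os.makedirs(osp.join(root, "VOCdevkit/VOC_all"), exist_ok=True)
    print(_dataset_exists(root))             # True: merged dir in place
    print(_dataset_exists("/tmp/none/voc"))  # False: path does not exist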
@@ -131,14 +166,14 @@ def _download(url, path, md5sum=None):
         url (str): download url
         path (str): download to given path
     """
-    if not os.path.exists(path):
+    if not osp.exists(path):
         os.makedirs(path)
 
     fname = url.split('/')[-1]
-    fullname = os.path.join(path, fname)
+    fullname = osp.join(path, fname)
     retry_cnt = 0
 
-    while not (os.path.exists(fullname) and _md5check(fullname, md5sum)):
+    while not (osp.exists(fullname) and _md5check(fullname, md5sum)):
         if retry_cnt < DOWNLOAD_RETRY_LIMIT:
             retry_cnt += 1
         else:
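
The retry loop keeps re-downloading until the file both exists and passes _md5check, giving up after DOWNLOAD_RETRY_LIMIT attempts. _md5check is defined elsewhere in this file; a minimal sketch of the contract the loop relies on (the actual implementation may differ):

    import hashlib

    def _md5check_sketch(fullname, md5sum=None):
        # no checksum supplied: accept the file as-is
        if md5sum is None:
            return True
        md5 = hashlib.md5()
        with open(fullname, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                md5.update(chunk)
        return md5.hexdigest() == md5sum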
@@ -197,8 +232,8 @@ def _decompress(fname):
     # successed, move decompress files to fpath and delete
     # fpath_tmp and download file.
     fpath = '/'.join(fname.split('/')[:-1])
-    fpath_tmp = os.path.join(fpath, 'tmp')
-    if os.path.isdir(fpath_tmp):
+    fpath_tmp = osp.join(fpath, 'tmp')
+    if osp.isdir(fpath_tmp):
         shutil.rmtree(fpath_tmp)
     os.makedirs(fpath_tmp)
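
The tmp-dir pattern here, sketched end to end with an illustrative helper (names are not from the diff): extracting into <fpath>/tmp first means a failed extraction never leaves a half-populated dataset directory behind.

    import os
    import os.path as osp
    import shutil
    import tarfile

    def _extract_via_tmp(fname, fpath):
        fpath_tmp = osp.join(fpath, 'tmp')
        if osp.isdir(fpath_tmp):
            shutil.rmtree(fpath_tmp)  # clear leftovers from a failed run
        os.makedirs(fpath_tmp)
        with tarfile.open(fname) as tf:
            tf.extractall(path=fpath_tmp)
        return fpath_tmp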
@@ -212,6 +247,30 @@ def _decompress(fname):
         raise TypeError("Unsupport compress file type {}".format(fname))
 
     for f in os.listdir(fpath_tmp):
-        shutil.move(os.path.join(fpath_tmp, f), os.path.join(fpath, f))
-    os.rmdir(fpath_tmp)
+        src_dir = osp.join(fpath_tmp, f)
+        dst_dir = osp.join(fpath, f)
+        _move_and_merge_tree(src_dir, dst_dir)
+    shutil.rmtree(fpath_tmp)
 
     os.remove(fname)
 
+
+def _move_and_merge_tree(src, dst):
+    """
+    Move src directory to dst; if dst already exists,
+    merge src into dst
+    """
+    if not osp.exists(dst):
+        shutil.move(src, dst)
+    else:
+        for fp in os.listdir(src):
+            src_fp = osp.join(src, fp)
+            dst_fp = osp.join(dst, fp)
+            # recurse into dirs that exist on both sides,
+            # move dirs that only exist in src
+            if osp.isdir(src_fp):
+                if osp.isdir(dst_fp):
+                    _move_and_merge_tree(src_fp, dst_fp)
+                else:
+                    shutil.move(src_fp, dst_fp)
+            # move a file only if dst does not already have it
+            elif osp.isfile(src_fp) and \
+                    not osp.isfile(dst_fp):
+                shutil.move(src_fp, dst_fp)
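
A toy demonstration (hypothetical paths, function above assumed in scope) of the merge behaviour the VOC download depends on: trees from VOC2007 and VOC2012 can land in the same destination without clobbering each other's files.

    import os

    for d in ("src/JPEGImages", "dst/JPEGImages"):
        os.makedirs(d, exist_ok=True)
    open("src/JPEGImages/000001.jpg", "w").close()
    open("dst/JPEGImages/000002.jpg", "w").close()

    _move_and_merge_tree("src", "dst")
    print(sorted(os.listdir("dst/JPEGImages")))  # both files survive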