提交 94a76969 编写于 作者: M Megvii Engine Team

refactor ImageNet

GitOrigin-RevId: f7774e0ffc5de7ffb3ea5eba5ddb9809b9d049dd
上级 e677ffc8
......@@ -24,7 +24,7 @@ from ....core.serialization import load, save
from ....distributed.util import is_distributed
from ....logger import get_logger
from .folder import ImageFolder
from .utils import _default_dataset_root, untar, untargz
from .utils import _default_dataset_root, calculate_md5, untar, untargz
logger = get_logger(__name__)
......@@ -33,40 +33,28 @@ class ImageNet(ImageFolder):
r"""
Load ImageNet from raw files or folder, expected folder looks like
raw files situation (optional):
root/ILSVRC2012_img_train.tar
root/ILSVRC2012_img_val.tar
root/ILSVRC2012_devkit_t12.tar.gz
image folder situation (required):
root/train/cls/xxx.${img_ext}
root/val/cls/xxx.${img_ext}
root/ILSVRC2012_devkit_t12/data/meta.mat
root/ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
If the required folders don't exist, raw files are required to get extracted and processed.
${root}/
| [REQUIRED TAR FILES]
|- ILSVRC2012_img_train.tar
|- ILSVRC2012_img_val.tar
|- ILSVRC2012_devkit_t12.tar.gz
| [OPTIONAL IMAGE FOLDERS]
|- train/cls/xxx.${img_ext}
|- val/cls/xxx.${img_ext}
|- ILSVRC2012_devkit_t12/data/meta.mat
|- ILSVRC2012_devkit_t12/data/ILSVRC2012_validation_ground_truth.txt
If the image folders don't exist, raw tar files are required to get extracted and processed.
"""
raw_file_meta = {
"train": ("ILSVRC2012_img_train.tar", "1d675b47d978889d74fa0da5fadfb00e"),
"val": ("ILSVRC2012_img_val.tar", "29b22e2961454d5413ddabcf34fc5622"),
"devkit": ("ILSVRC2012_devkit_t12.tar.gz", "fa75699e90414af021442c21a62c3abf"),
}
"""
raw files of ImageNet (train, val, devkit)
"""
} # ImageNet raw files
default_train_dir = "train"
"""
directory of train data
"""
default_val_dir = "val"
"""
directory of val data
"""
default_devkit_dir = "ILSVRC2012_devkit_t12"
"""
directory of devkit
"""
def __init__(self, root: str = None, train: bool = True, **kwargs):
r"""
......@@ -97,13 +85,16 @@ class ImageNet(ImageFolder):
else:
self.root = root
self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
if not os.path.exists(self.root):
raise FileNotFoundError("dir %s does not exist" % self.root)
self.devkit_dir = os.path.join(self.root, self.default_devkit_dir)
if not os.path.exists(self.devkit_dir):
logger.warning("devkit directory %s does not exists" % self.devkit_dir)
self._prepare_devkit()
self.train = train
if train:
self.target_folder = os.path.join(self.root, self.default_train_dir)
......@@ -125,7 +116,7 @@ class ImageNet(ImageFolder):
"extracting raw file shouldn't be done in distributed mode, use single process instead"
)
else:
self.parse(train)
self._prepare_train() if train else self._prepare_val()
super().__init__(self.target_folder, **kwargs)
......@@ -180,14 +171,13 @@ class ImageNet(ImageFolder):
]
)
def organize_val_data(self):
def _organize_val_data(self):
id2wnid = self.meta[0]
val_idcs = self.valid_ground_truth
val_wnids = [id2wnid[idx] for idx in val_idcs]
raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val")
val_images = sorted(
[os.path.join(raw_val_dir, image) for image in os.listdir(raw_val_dir)]
[os.path.join(self.target_folder, image) for image in os.listdir(self.target_folder)]
)
logger.debug("mkdir for val set wnids")
......@@ -203,24 +193,41 @@ class ImageNet(ImageFolder):
),
)
def parse(self, train):
if train:
logger.info("process train raw file.. this may take several hours")
untar(
os.path.join(self.root, self.raw_file_meta["train"][0]),
self.target_folder,
)
paths = [
os.path.join(self.target_folder, child_dir)
for child_dir in os.listdir(self.target_folder)
]
for path in tqdm(paths):
untar(path, os.path.splitext(path)[0], remove=True)
else:
logger.info("process devkit file..")
untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))
logger.info("process valid raw file.. this may take 10-20 minutes")
raw_val_dir = os.path.join(self.root, "ILSVRC2012_img_val")
os.makedirs(raw_val_dir, exist_ok=True)
untar(os.path.join(self.root, self.raw_file_meta["val"][0]), raw_val_dir)
self.organize_val_data()
def _prepare_val(self):
assert not self.train
raw_filename, checksum = self.raw_file_meta["val"]
raw_file = os.path.join(self.root, raw_filename)
logger.info("checksum valid tar file {} ..".format(raw_file))
assert calculate_md5(raw_file) == checksum, \
"checksum mismatch, {} may be damaged".format(raw_file)
logger.info("extract valid tar file.. this may take 10-20 minutes")
untar(os.path.join(self.root, raw_file), self.target_folder)
self._organize_val_data()
def _prepare_train(self):
assert self.train
raw_filename, checksum = self.raw_file_meta["train"]
raw_file = os.path.join(self.root, raw_filename)
logger.info("checksum train tar file {} ..".format(raw_file))
assert calculate_md5(raw_file) == checksum, \
"checksum mismatch, {} may be damaged".format(raw_file)
logger.info("extract train tar file.. this may take several hours")
untar(
os.path.join(self.root, raw_file),
self.target_folder,
)
paths = [
os.path.join(self.target_folder, child_dir)
for child_dir in os.listdir(self.target_folder)
]
for path in tqdm(paths):
untar(path, os.path.splitext(path)[0], remove=True)
def _prepare_devkit(self):
raw_filename, checksum = self.raw_file_meta["val"]
raw_file = os.path.join(self.root, raw_filename)
logger.info("checksum devkit tar file {} ..".format(raw_file))
assert calculate_md5(raw_file) == checksum, \
"checksum mismatch, {} may be damaged".format(raw_file)
logger.info("extract devkit file..")
untargz(os.path.join(self.root, self.raw_file_meta["devkit"][0]))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册